Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
9fed1f5d
"tests/git@developer.sourcefind.cn:OpenDAS/llama-factory.git" did not exist on "7ea81099235fd4ccf8d4b9ba202e76cce40b5cc8"
Commit
9fed1f5d
authored
Mar 23, 2024
by
zhuwenwen
Browse files
add bf16
parent
3f1166ab
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
9 deletions
+9
-9
csrc/attention/attention_dtypes.h
csrc/attention/attention_dtypes.h
+1
-1
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+4
-4
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+4
-4
No files found.
csrc/attention/attention_dtypes.h
View file @
9fed1f5d
...
@@ -3,5 +3,5 @@
...
@@ -3,5 +3,5 @@
#include "attention_generic.cuh"
#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_float32.cuh"
//
#include "dtype_bfloat16.cuh"
#include "dtype_bfloat16.cuh"
// #include "dtype_fp8_e5m2.cuh"
// #include "dtype_fp8_e5m2.cuh"
csrc/attention/attention_kernels.cu
View file @
9fed1f5d
...
@@ -734,8 +734,8 @@ void paged_attention_v1(
...
@@ -734,8 +734,8 @@ void paged_attention_v1(
CALL_V1_LAUNCHER_BLOCK_SIZE
(
float
,
float
,
false
);
CALL_V1_LAUNCHER_BLOCK_SIZE
(
float
,
float
,
false
);
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_V1_LAUNCHER_BLOCK_SIZE
(
uint16_t
,
uint16_t
,
false
);
CALL_V1_LAUNCHER_BLOCK_SIZE
(
uint16_t
,
uint16_t
,
false
);
//
} else if (query.dtype() == at::ScalarType::BFloat16) {
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
//
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
CALL_V1_LAUNCHER_BLOCK_SIZE
(
__nv_bfloat16
,
__nv_bfloat16
,
false
);
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
query
.
dtype
());
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
query
.
dtype
());
}
}
...
@@ -927,8 +927,8 @@ void paged_attention_v2(
...
@@ -927,8 +927,8 @@ void paged_attention_v2(
CALL_V2_LAUNCHER_BLOCK_SIZE
(
float
,
float
,
false
);
CALL_V2_LAUNCHER_BLOCK_SIZE
(
float
,
float
,
false
);
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_V2_LAUNCHER_BLOCK_SIZE
(
uint16_t
,
uint16_t
,
false
);
CALL_V2_LAUNCHER_BLOCK_SIZE
(
uint16_t
,
uint16_t
,
false
);
//
} else if (query.dtype() == at::ScalarType::BFloat16) {
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
//
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
CALL_V2_LAUNCHER_BLOCK_SIZE
(
__nv_bfloat16
,
__nv_bfloat16
,
false
);
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
query
.
dtype
());
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
query
.
dtype
());
}
}
...
...
csrc/cache_kernels.cu
View file @
9fed1f5d
...
@@ -13,10 +13,10 @@
...
@@ -13,10 +13,10 @@
#include <map>
#include <map>
#include <vector>
#include <vector>
//
#ifdef USE_ROCM
#ifdef USE_ROCM
//
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
//
typedef __hip_bfloat16 __nv_bfloat16;
typedef
__hip_bfloat16
__nv_bfloat16
;
//
#endif
#endif
void
swap_blocks
(
void
swap_blocks
(
torch
::
Tensor
&
src
,
torch
::
Tensor
&
src
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment