Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
317a7889
Commit
317a7889
authored
Mar 26, 2024
by
zhuwenwen
Browse files
merge v0.3.3
parents
7477e8f3
73b1705e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
39 deletions
+33
-39
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+14
-20
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+14
-14
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+5
-5
No files found.
csrc/attention/attention_kernels.cu
View file @
317a7889
...
...
@@ -736,19 +736,16 @@ void paged_attention_v1(
CALL_V1_LAUNCHER_BLOCK_SIZE
(
uint16_t
,
uint16_t
,
false
);
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_V1_LAUNCHER_BLOCK_SIZE
(
__nv_bfloat16
,
__nv_bfloat16
,
false
);
}
else
if
(
kv_cache_dtype
==
"fp8_e5m2"
)
{
if
(
query
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_V1_LAUNCHER_BLOCK_SIZE
(
float
,
uint8_t
,
true
);
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_V1_LAUNCHER_BLOCK_SIZE
(
uint16_t
,
uint8_t
,
true
);
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_V1_LAUNCHER_BLOCK_SIZE
(
__nv_bfloat16
,
uint8_t
,
true
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
query
.
dtype
());
}
// } else if (kv_cache_dtype == "fp8_e5m2") {
// if (query.dtype() == at::ScalarType::Float) {
// CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
// } else if (query.dtype() == at::ScalarType::Half) {
// CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
// } else if (query.dtype() == at::ScalarType::BFloat16) {
// CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
// } else {
// TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
// }
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type of kv cache: "
,
kv_cache_dtype
);
}
...
...
@@ -929,19 +926,16 @@ void paged_attention_v2(
CALL_V2_LAUNCHER_BLOCK_SIZE
(
uint16_t
,
uint16_t
,
false
);
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_V2_LAUNCHER_BLOCK_SIZE
(
__nv_bfloat16
,
__nv_bfloat16
,
false
);
}
else
if
(
kv_cache_dtype
==
"fp8_e5m2"
)
{
if
(
query
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_V2_LAUNCHER_BLOCK_SIZE
(
float
,
uint8_t
,
true
);
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_V2_LAUNCHER_BLOCK_SIZE
(
uint16_t
,
uint8_t
,
true
);
}
else
if
(
query
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_V2_LAUNCHER_BLOCK_SIZE
(
__nv_bfloat16
,
uint8_t
,
true
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
query
.
dtype
());
}
// } else if (kv_cache_dtype == "fp8_e5m2") {
// if (query.dtype() == at::ScalarType::Float) {
// CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
// } else if (query.dtype() == at::ScalarType::Half) {
// CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
// } else if (query.dtype() == at::ScalarType::BFloat16) {
// CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
// } else {
// TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
// }
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type of kv cache: "
,
kv_cache_dtype
);
}
...
...
csrc/cache_kernels.cu
View file @
317a7889
...
...
@@ -251,17 +251,17 @@ void reshape_and_cache(
CALL_RESHAPE_AND_CACHE
(
float
,
float
,
false
);
}
else
if
(
key
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_RESHAPE_AND_CACHE
(
uint16_t
,
uint16_t
,
false
);
// } else if (key.dtype() == at::ScalarType::BFloat16) {
// CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
}
else
if
(
key
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_RESHAPE_AND_CACHE
(
__nv_bfloat16
,
__nv_bfloat16
,
false
);
}
}
else
if
(
kv_cache_dtype
==
"fp8_e5m2"
)
{
if
(
key
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_RESHAPE_AND_CACHE
(
float
,
uint8_t
,
true
);
}
else
if
(
key
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_RESHAPE_AND_CACHE
(
uint16_t
,
uint8_t
,
true
);
}
else
if
(
key
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_RESHAPE_AND_CACHE
(
__nv_bfloat16
,
uint8_t
,
true
);
}
// } else if (kv_cache_dtype == "fp8_e5m2") {
// if (key.dtype() == at::ScalarType::Float) {
// CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
// } else if (key.dtype() == at::ScalarType::Half) {
// CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
// } else if (key.dtype() == at::ScalarType::BFloat16) {
// CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
// }
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type of kv cache: "
,
kv_cache_dtype
);
}
...
...
@@ -308,13 +308,13 @@ void convert_fp8_e5m2(
CALL_CONVERT_FP8_E5M2
(
uint8_t
,
float
);
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CONVERT_FP8_E5M2
(
uint8_t
,
uint16_t
);
// } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
// CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8_E5M2
(
uint8_t
,
__nv_bfloat16
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_CONVERT_FP8_E5M2
(
float
,
uint8_t
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CONVERT_FP8_E5M2
(
uint16_t
,
uint8_t
);
// } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
// CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8_E5M2
(
__nv_bfloat16
,
uint8_t
);
}
}
csrc/dispatch_utils.h
View file @
317a7889
...
...
@@ -8,8 +8,8 @@
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
// AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
...
...
@@ -17,9 +17,9 @@
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
// AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
// AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment