Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
58c0a928
Unverified
Commit
58c0a928
authored
Apr 10, 2026
by
Isotr0py
Committed by
GitHub
Apr 09, 2026
Browse files
[Bugfix] Fix broken explicit unquantized kv cache dtype support (#38922)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
3dd60971
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
52 additions
and
50 deletions
+52
-50
csrc/attention/dtype_fp8.cuh
csrc/attention/dtype_fp8.cuh
+16
-0
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
+13
-14
csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
+23
-36
No files found.
csrc/attention/dtype_fp8.cuh
View file @
58c0a928
...
@@ -17,6 +17,22 @@ enum class Fp8KVCacheDataType {
...
@@ -17,6 +17,22 @@ enum class Fp8KVCacheDataType {
kFp8E5M2
=
2
,
kFp8E5M2
=
2
,
};
};
inline
Fp8KVCacheDataType
get_fp8_kv_cache_data_type
(
const
std
::
string
&
dtype_str
)
{
// dtype_str refers to CacheDType at vllm.config.cache.CacheDType
if
(
dtype_str
==
"auto"
||
dtype_str
==
"float16"
||
dtype_str
==
"bfloat16"
)
{
// unquantized kv cache
return
Fp8KVCacheDataType
::
kAuto
;
}
else
if
(
dtype_str
==
"fp8"
||
dtype_str
==
"fp8_ds_mla"
||
dtype_str
==
"fp8_e4m3"
)
{
return
Fp8KVCacheDataType
::
kFp8E4M3
;
}
else
if
(
dtype_str
==
"fp8_e5m2"
)
{
return
Fp8KVCacheDataType
::
kFp8E5M2
;
}
TORCH_CHECK
(
false
,
"Unsupported fp8 kv cache data type: "
,
dtype_str
);
}
// fp8 vector types for quantization of kv cache
// fp8 vector types for quantization of kv cache
template
<
>
template
<
>
struct
Vec
<
uint8_t
,
1
>
{
struct
Vec
<
uint8_t
,
1
>
{
...
...
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
View file @
58c0a928
...
@@ -639,7 +639,9 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -639,7 +639,9 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// function with template<typename scalar_t, typename cache_t,
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
vllm::Fp8KVCacheDataType KV_CACHE_DTYPE = \
vllm::get_fp8_kv_cache_data_type(KV_DTYPE); \
if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
...
@@ -649,21 +651,18 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -649,21 +651,18 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} else { \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} \
} else { \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E4M3) { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
} else { \
TORCH_CHECK(false, "Unsupported
data
type of kv cache: ",
KV
_DTYPE);
\
TORCH_CHECK(false, "Unsupported
input
type of kv cache: ",
SRC
_DTYPE); \
} \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
}
}
}
// namespace fp8
}
// namespace fp8
...
...
csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
View file @
58c0a928
...
@@ -543,7 +543,9 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -543,7 +543,9 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// function with template<typename scalar_t, typename cache_t,
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
vllm::Fp8KVCacheDataType KV_CACHE_DTYPE = \
vllm::get_fp8_kv_cache_data_type(KV_DTYPE); \
if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
...
@@ -553,43 +555,28 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -553,43 +555,28 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} else { \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} \
} else { \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E4M3) { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
TORCH_CHECK(false, \
} \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E5M2) { \
} \
if (SRC_DTYPE == at::ScalarType::Float) { \
} else if (KV_DTYPE == "fp8_e5m2") { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
if (SRC_DTYPE == at::ScalarType::Float) { \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_ds_mla") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
} else { \
TORCH_CHECK(false, "Unsupported
data
type of kv cache: ",
KV
_DTYPE);
\
TORCH_CHECK(false, "Unsupported
input
type of kv cache: ",
SRC
_DTYPE); \
} \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
}
}
}
// namespace fp8
}
// namespace fp8
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment