"cacheflow/core/server.py" did not exist on "0f4b32199ec6c5d16bc03767e36fff2d54559ff8"
Unverified Commit 58c0a928 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix broken explicit unquantized kv cache dtype support (#38922)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 3dd60971
...@@ -17,6 +17,22 @@ enum class Fp8KVCacheDataType { ...@@ -17,6 +17,22 @@ enum class Fp8KVCacheDataType {
kFp8E5M2 = 2, kFp8E5M2 = 2,
}; };
// Translates a kv-cache dtype string into the matching Fp8KVCacheDataType.
//
// dtype_str refers to CacheDType at vllm.config.cache.CacheDType.
//
// Mapping:
//   "auto", "float16", "bfloat16"      -> kAuto (unquantized kv cache)
//   "fp8", "fp8_ds_mla", "fp8_e4m3"    -> kFp8E4M3
//   "fp8_e5m2"                         -> kFp8E5M2
// Any other string fails the TORCH_CHECK below.
inline Fp8KVCacheDataType get_fp8_kv_cache_data_type(
    const std::string& dtype_str) {
  // Unquantized kv cache: explicit float16/bfloat16 are handled like "auto".
  const bool is_unquantized = dtype_str == "auto" ||
                              dtype_str == "float16" ||
                              dtype_str == "bfloat16";
  if (is_unquantized) {
    return Fp8KVCacheDataType::kAuto;
  }

  // "fp8" and "fp8_ds_mla" resolve to the same enum value as "fp8_e4m3".
  const bool is_e4m3 = dtype_str == "fp8" ||
                       dtype_str == "fp8_ds_mla" ||
                       dtype_str == "fp8_e4m3";
  if (is_e4m3) {
    return Fp8KVCacheDataType::kFp8E4M3;
  }

  if (dtype_str == "fp8_e5m2") {
    return Fp8KVCacheDataType::kFp8E5M2;
  }

  // No known dtype matched.
  TORCH_CHECK(false, "Unsupported fp8 kv cache data type: ", dtype_str);
}
// fp8 vector types for quantization of kv cache // fp8 vector types for quantization of kv cache
template <> template <>
struct Vec<uint8_t, 1> { struct Vec<uint8_t, 1> {
......
...@@ -639,7 +639,9 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -639,7 +639,9 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// function with template<typename scalar_t, typename cache_t, // function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>. // Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \ vllm::Fp8KVCacheDataType KV_CACHE_DTYPE = \
vllm::get_fp8_kv_cache_data_type(KV_DTYPE); \
if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \ } else if (SRC_DTYPE == at::ScalarType::Half) { \
...@@ -649,21 +651,18 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -649,21 +651,18 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} else { \ } else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else { \ } else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E4M3) { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == at::ScalarType::Float) { \ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == at::ScalarType::Half) { \ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} }
} // namespace fp8 } // namespace fp8
......
...@@ -543,7 +543,9 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -543,7 +543,9 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// function with template<typename scalar_t, typename cache_t, // function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>. // Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \ vllm::Fp8KVCacheDataType KV_CACHE_DTYPE = \
vllm::get_fp8_kv_cache_data_type(KV_DTYPE); \
if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \ } else if (SRC_DTYPE == at::ScalarType::Half) { \
...@@ -553,43 +555,28 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -553,43 +555,28 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} else { \ } else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else { \ } else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E4M3) { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ if (SRC_DTYPE == at::ScalarType::Float) { \
if (SRC_DTYPE == at::ScalarType::Float) { \ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \
} else if (SRC_DTYPE == at::ScalarType::Half) { \ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ } else { \
} else { \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
TORCH_CHECK(false, \ } \
"Unsupported input type of kv cache: ", SRC_DTYPE); \ } else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E5M2) { \
} \ if (SRC_DTYPE == at::ScalarType::Float) { \
} else if (KV_DTYPE == "fp8_e5m2") { \ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
if (SRC_DTYPE == at::ScalarType::Float) { \ } else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_ds_mla") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} }
} // namespace fp8 } // namespace fp8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment