Commit 318e2b5a authored by zhuwenwen's avatar zhuwenwen
Browse files

skip fp8

parent 51679bbd
...@@ -244,17 +244,17 @@ void reshape_and_cache( ...@@ -244,17 +244,17 @@ void reshape_and_cache(
CALL_RESHAPE_AND_CACHE(float, float, false); CALL_RESHAPE_AND_CACHE(float, float, false);
} else if (key.dtype() == at::ScalarType::Half) { } else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
} else if (key.dtype() == at::ScalarType::BFloat16) { // } else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); // CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
}
} else if (kv_cache_dtype == "fp8_e5m2") {
if (key.dtype() == at::ScalarType::Float) {
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
} else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
} }
// } else if (kv_cache_dtype == "fp8_e5m2") {
// if (key.dtype() == at::ScalarType::Float) {
// CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
// } else if (key.dtype() == at::ScalarType::Half) {
// CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
// } else if (key.dtype() == at::ScalarType::BFloat16) {
// CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
// }
} else { } else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
} }
...@@ -462,13 +462,13 @@ void convert_fp8_e5m2( ...@@ -462,13 +462,13 @@ void convert_fp8_e5m2(
CALL_CONVERT_FP8_E5M2(uint8_t, float); CALL_CONVERT_FP8_E5M2(uint8_t, float);
} else if (src_cache.dtype() == at::ScalarType::Half) { } else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t); CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) { // } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16); // CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
} else if (dst_cache.dtype() == at::ScalarType::Float) { } else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8_E5M2(float, uint8_t); CALL_CONVERT_FP8_E5M2(float, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::Half) { } else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t); CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) { // } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t); // CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
} }
} }
...@@ -8,8 +8,8 @@ ...@@ -8,8 +8,8 @@
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) // AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH( \
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ // AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) // AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH( \
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "../../attention/attention_dtypes.h" #include "../../attention/attention_dtypes.h"
#include "../../attention/dtype_float32.cuh" #include "../../attention/dtype_float32.cuh"
#include "../../attention/dtype_float16.cuh" #include "../../attention/dtype_float16.cuh"
#include "../../attention/dtype_bfloat16.cuh" // #include "../../attention/dtype_bfloat16.cuh"
#pragma once #pragma once
......
...@@ -26,9 +26,9 @@ logger = init_logger(__name__) ...@@ -26,9 +26,9 @@ logger = init_logger(__name__)
STR_DTYPE_TO_TORCH_DTYPE = { STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half, "half": torch.half,
"bfloat16": torch.bfloat16, # "bfloat16": torch.bfloat16,
"float": torch.float, "float": torch.float,
"fp8_e5m2": torch.uint8, # "fp8_e5m2": torch.uint8,
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment