Commit 73b1705e authored by zhuwenwen
Browse files

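Add fp8: re-enable the previously commented-out fp8_e5m2 KV-cache paths (paged_attention_v1/v2, reshape_and_cache, convert_fp8_e5m2 BFloat16 branch) and the Byte dispatch case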

parent dd823f7f
...@@ -739,16 +739,16 @@ void paged_attention_v1( ...@@ -739,16 +739,16 @@ void paged_attention_v1(
} else { } else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
} }
// } else if (kv_cache_dtype == "fp8_e5m2") { } else if (kv_cache_dtype == "fp8_e5m2") {
// if (query.dtype() == at::ScalarType::Float) { if (query.dtype() == at::ScalarType::Float) {
// CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
// } else if (query.dtype() == at::ScalarType::Half) { } else if (query.dtype() == at::ScalarType::Half) {
// CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
// } else if (query.dtype() == at::ScalarType::BFloat16) { } else if (query.dtype() == at::ScalarType::BFloat16) {
// CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
// } else { } else {
// TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
// } }
} else { } else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
} }
...@@ -932,16 +932,16 @@ void paged_attention_v2( ...@@ -932,16 +932,16 @@ void paged_attention_v2(
} else { } else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
} }
// } else if (kv_cache_dtype == "fp8_e5m2") { } else if (kv_cache_dtype == "fp8_e5m2") {
// if (query.dtype() == at::ScalarType::Float) { if (query.dtype() == at::ScalarType::Float) {
// CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
// } else if (query.dtype() == at::ScalarType::Half) { } else if (query.dtype() == at::ScalarType::Half) {
// CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
// } else if (query.dtype() == at::ScalarType::BFloat16) { } else if (query.dtype() == at::ScalarType::BFloat16) {
// CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
// } else { } else {
// TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
// } }
} else { } else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
} }
......
...@@ -254,14 +254,14 @@ void reshape_and_cache( ...@@ -254,14 +254,14 @@ void reshape_and_cache(
} else if (key.dtype() == at::ScalarType::BFloat16) { } else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
} }
// } else if (kv_cache_dtype == "fp8_e5m2") { } else if (kv_cache_dtype == "fp8_e5m2") {
// if (key.dtype() == at::ScalarType::Float) { if (key.dtype() == at::ScalarType::Float) {
// CALL_RESHAPE_AND_CACHE(float, uint8_t, true); CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
// } else if (key.dtype() == at::ScalarType::Half) { } else if (key.dtype() == at::ScalarType::Half) {
// CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
// } else if (key.dtype() == at::ScalarType::BFloat16) { } else if (key.dtype() == at::ScalarType::BFloat16) {
// CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
// } }
} else { } else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
} }
...@@ -314,7 +314,7 @@ void convert_fp8_e5m2( ...@@ -314,7 +314,7 @@ void convert_fp8_e5m2(
CALL_CONVERT_FP8_E5M2(float, uint8_t); CALL_CONVERT_FP8_E5M2(float, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::Half) { } else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t); CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
// } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
// CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t); CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
} }
} }
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
// AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH( \
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment