Commit b2d14ba3 authored by yangql's avatar yangql
Browse files

修复kvcache-fp8—e5m2的不能开cp的bug

parent bb13d854
...@@ -365,21 +365,21 @@ inline __device__ uint8_t float_to_fp8e5m2(float f) { ...@@ -365,21 +365,21 @@ inline __device__ uint8_t float_to_fp8e5m2(float f) {
// fp8 // fp8
template <typename Tin> template <typename Tin>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion_e5m2(const Tin& a, float scale) { scaled_vec_conversion_to_e5m2(const Tin& a, float scale) {
return 0; return 0;
} }
// float -> fp8 // float -> fp8
template <> template <>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion_e5m2<float>(const float& a, float scale) { scaled_vec_conversion_to_e5m2<float>(const float& a, float scale) {
return float_to_fp8e5m2(a / scale); return float_to_fp8e5m2(a / scale);
} }
// half -> fp8 // half -> fp8
template <> template <>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion_e5m2<uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion_to_e5m2<uint16_t>(const uint16_t& a, float scale) {
float res_f = half_to_float(a) / scale; float res_f = half_to_float(a) / scale;
return float_to_fp8e5m2(res_f); return float_to_fp8e5m2(res_f);
} }
...@@ -387,11 +387,49 @@ scaled_vec_conversion_e5m2<uint16_t>(const uint16_t& a, float scale) { ...@@ -387,11 +387,49 @@ scaled_vec_conversion_e5m2<uint16_t>(const uint16_t& a, float scale) {
// bf16 -> fp8 // bf16 -> fp8
template <> template <>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion_e5m2<__nv_bfloat16>(const __nv_bfloat16& a, float scale) { scaled_vec_conversion_to_e5m2<__nv_bfloat16>(const __nv_bfloat16& a, float scale) {
float res_f = (static_cast<float>(a)) / scale; float res_f = (static_cast<float>(a)) / scale;
return float_to_fp8e5m2(res_f); return float_to_fp8e5m2(res_f);
} }
inline __device__ float fp8e5m2_to_fp32(const uint8_t& input) {
union uf16{
uint16_t as_bits;
_Float16 as_value;
} ;
uf16 u16;
u16.as_bits = (uint16_t)input << 8;
return (float)u16.as_value;
}
template <typename Tout>
__inline__ __device__ Tout
scaled_vec_conversion_from_e5m2(const uint8_t& a, float scale) {
return 0;
}
// fp8 -> float
template <>
__inline__ __device__ float
scaled_vec_conversion_from_e5m2<float>(const uint8_t& a, float scale) {
return fp8e5m2_to_fp32(a)*scale;
}
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion_from_e5m2<uint16_t>(const uint8_t& a, float scale) {
return float_to_half(fp8e5m2_to_fp32(a)*scale);
}
// fp8 -> bf16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
return __float2bfloat16(fp8e5m2_to_fp32(a)*scale);
}
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...@@ -399,12 +437,11 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -399,12 +437,11 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
return scaled_vec_conversion<Tout, Tin>(x, scale); return scaled_vec_conversion<Tout, Tin>(x, scale);
} }
else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tout)==1){ else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tout)==1){
return scaled_vec_conversion_e5m2<Tin>(x, scale); return scaled_vec_conversion_to_e5m2<Tin>(x, scale);
}
else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tin)==1){
return scaled_vec_conversion_from_e5m2<Tout>(x, scale);
} }
// else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 &&
// (std::is_same<Tin, uint16_t>::value||std::is_same<Tin, __nv_bfloat16>::value)){
// return scaled_vec_conversion_e5m2<Tin>(x, scale);
// }
return {}; // Squash missing return statement warning return {}; // Squash missing return statement warning
} }
......
...@@ -2166,12 +2166,15 @@ def gather_cache(src_cache: torch.Tensor, ...@@ -2166,12 +2166,15 @@ def gather_cache(src_cache: torch.Tensor,
kv_dtype = "auto", kv_dtype = "auto",
scale: float = 1.0, scale: float = 1.0,
) -> None: ) -> None:
#支持"kv cache fp8" #支持"kv cache fp8" 临时方案,带dtype的gather_cache在vllm0.10后会实现。
if kv_dtype == "fp8": if kv_dtype == "fp8" or kv_dtype == "fp8_e5m2" or kv_dtype == "fp8_e4m3":
dst_fp8 = torch.zeros(dst.shape, dtype=torch.uint8, device=dst.device) dst_fp8 = torch.empty(dst.shape, dtype=torch.uint8, device=dst.device)
convert_fp8(dst_fp8, dst, scale, kv_dtype) #convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch.ops._C_cache_ops.gather_cache(src_cache, dst_fp8, block_table, torch.ops._C_cache_ops.gather_cache(src_cache, dst_fp8, block_table,
cu_seq_lens, batch_size, seq_starts) cu_seq_lens, batch_size, seq_starts)
#dst_fp8->bf16
convert_fp8(dst, dst_fp8, scale, kv_dtype)
else: else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts) cu_seq_lens, batch_size, seq_starts)
......
...@@ -211,9 +211,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -211,9 +211,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl") "FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
if self.kv_cache_dtype != "fp8": if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
raise NotImplementedError( return
"FlashMLA with other KV cache not yet supported") raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode( def _forward_decode(
self, self,
......
...@@ -24,6 +24,11 @@ from vllm.platforms import _Backend, current_platform ...@@ -24,6 +24,11 @@ from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target from vllm.v1.attention.backends.utils import validate_kv_sharing_target
USE_XFORMERS_OPS = None
try:
tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment]
class Attention(nn.Module): class Attention(nn.Module):
"""Attention layer. """Attention layer.
...@@ -204,10 +209,12 @@ class Attention(nn.Module): ...@@ -204,10 +209,12 @@ class Attention(nn.Module):
`vllm.forward_context.get_forward_context().attn_metadata`. `vllm.forward_context.get_forward_context().attn_metadata`.
""" """
if self.calculate_kv_scales: if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata # attn_metadata = get_forward_context().attn_metadata
if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)): # #if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# if key is not None and value is not None: # if key is not None and value is not None:
self.calc_kv_scales(query, key, value) # self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name)
if self.use_output: if self.use_output:
output_shape = (output_shape output_shape = (output_shape
if output_shape is not None else query.shape) if output_shape is not None else query.shape)
...@@ -395,7 +402,42 @@ def maybe_save_kv_layer_to_connector( ...@@ -395,7 +402,42 @@ def maybe_save_kv_layer_to_connector(
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer, connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name]) attn_metadata[layer_name])
def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
# if attn_metadata is None or not getattr(
# attn_metadata, 'enable_kv_scales_calculation', False):
# return
self = forward_context.no_compile_layers[layer_name]
self.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake( query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales,
mutates_args=[],
fake_impl=maybe_calc_kv_scales_fake,
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,)
def unified_attention( def unified_attention(
query: torch.Tensor, query: torch.Tensor,
......
...@@ -99,7 +99,8 @@ def flash_mla_with_kvcache( ...@@ -99,7 +99,8 @@ def flash_mla_with_kvcache(
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5) softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm(): if current_platform.is_rocm():
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
kv_dtype = "fp8_e4m3" if kv_cache_dtype == "fp8" else kv_cache_dtype
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q, q,
k_cache, k_cache,
...@@ -112,7 +113,7 @@ def flash_mla_with_kvcache( ...@@ -112,7 +113,7 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
k_scale, k_scale,
"fp8_e4m3", kv_dtype,
) )
return out, softmax_lse return out, softmax_lse
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
......
...@@ -183,8 +183,8 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -183,8 +183,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
"float": torch.float, "float": torch.float,
"fp8": torch.uint8, "fp8": torch.uint8,
# "fp8_e4m3": torch.uint8, "fp8_e4m3": torch.uint8,
# "fp8_e5m2": torch.uint8, "fp8_e5m2": torch.uint8,
"int8": torch.int8, "int8": torch.int8,
} }
......
...@@ -150,9 +150,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -150,9 +150,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl") "FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
if self.kv_cache_dtype != "fp8": if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
raise NotImplementedError( return
"FlashMLA with other KV cache not yet supported") raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode( def _forward_decode(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment