"wrappers/python/vscode:/vscode.git/clone" did not exist on "6c82c11ba6feb50953e33131072a5fe0e925422c"
Commit b2d14ba3 authored by yangql's avatar yangql
Browse files

修复kvcache-fp8—e5m2的不能开cp的bug

parent bb13d854
......@@ -365,21 +365,21 @@ inline __device__ uint8_t float_to_fp8e5m2(float f) {
// fp8
template <typename Tin>
__inline__ __device__ uint8_t
scaled_vec_conversion_e5m2(const Tin& a, float scale) {
scaled_vec_conversion_to_e5m2(const Tin& a, float scale) {
return 0;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_e5m2<float>(const float& a, float scale) {
scaled_vec_conversion_to_e5m2<float>(const float& a, float scale) {
return float_to_fp8e5m2(a / scale);
}
// half -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_e5m2<uint16_t>(const uint16_t& a, float scale) {
scaled_vec_conversion_to_e5m2<uint16_t>(const uint16_t& a, float scale) {
float res_f = half_to_float(a) / scale;
return float_to_fp8e5m2(res_f);
}
......@@ -387,11 +387,49 @@ scaled_vec_conversion_e5m2<uint16_t>(const uint16_t& a, float scale) {
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_e5m2<__nv_bfloat16>(const __nv_bfloat16& a, float scale) {
scaled_vec_conversion_to_e5m2<__nv_bfloat16>(const __nv_bfloat16& a, float scale) {
float res_f = (static_cast<float>(a)) / scale;
return float_to_fp8e5m2(res_f);
}
inline __device__ float fp8e5m2_to_fp32(const uint8_t& input) {
union uf16{
uint16_t as_bits;
_Float16 as_value;
} ;
uf16 u16;
u16.as_bits = (uint16_t)input << 8;
return (float)u16.as_value;
}
template <typename Tout>
__inline__ __device__ Tout
scaled_vec_conversion_from_e5m2(const uint8_t& a, float scale) {
return 0;
}
// fp8 -> float
template <>
__inline__ __device__ float
scaled_vec_conversion_from_e5m2<float>(const uint8_t& a, float scale) {
return fp8e5m2_to_fp32(a)*scale;
}
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion_from_e5m2<uint16_t>(const uint8_t& a, float scale) {
return float_to_half(fp8e5m2_to_fp32(a)*scale);
}
// fp8 -> bf16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
return __float2bfloat16(fp8e5m2_to_fp32(a)*scale);
}
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
......@@ -399,12 +437,11 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
return scaled_vec_conversion<Tout, Tin>(x, scale);
}
else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tout)==1){
return scaled_vec_conversion_e5m2<Tin>(x, scale);
return scaled_vec_conversion_to_e5m2<Tin>(x, scale);
}
else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tin)==1){
return scaled_vec_conversion_from_e5m2<Tout>(x, scale);
}
// else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 &&
// (std::is_same<Tin, uint16_t>::value||std::is_same<Tin, __nv_bfloat16>::value)){
// return scaled_vec_conversion_e5m2<Tin>(x, scale);
// }
return {}; // Squash missing return statement warning
}
......
......@@ -2166,12 +2166,15 @@ def gather_cache(src_cache: torch.Tensor,
kv_dtype = "auto",
scale: float = 1.0,
) -> None:
#支持"kv cache fp8"
if kv_dtype == "fp8":
dst_fp8 = torch.zeros(dst.shape, dtype=torch.uint8, device=dst.device)
convert_fp8(dst_fp8, dst, scale, kv_dtype)
#支持"kv cache fp8" 临时方案,带dtype的gather_cache在vllm0.10后会实现。
if kv_dtype == "fp8" or kv_dtype == "fp8_e5m2" or kv_dtype == "fp8_e4m3":
dst_fp8 = torch.empty(dst.shape, dtype=torch.uint8, device=dst.device)
#convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch.ops._C_cache_ops.gather_cache(src_cache, dst_fp8, block_table,
cu_seq_lens, batch_size, seq_starts)
#dst_fp8->bf16
convert_fp8(dst, dst_fp8, scale, kv_dtype)
else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
......
......@@ -211,9 +211,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
if self.kv_cache_dtype != "fp8":
raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
return
raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode(
self,
......
......@@ -24,6 +24,11 @@ from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
USE_XFORMERS_OPS = None
try:
tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment]
class Attention(nn.Module):
"""Attention layer.
......@@ -204,10 +209,12 @@ class Attention(nn.Module):
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata
if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# attn_metadata = get_forward_context().attn_metadata
# #if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# if key is not None and value is not None:
self.calc_kv_scales(query, key, value)
# self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name)
if self.use_output:
output_shape = (output_shape
if output_shape is not None else query.shape)
......@@ -395,7 +402,42 @@ def maybe_save_kv_layer_to_connector(
assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
# if attn_metadata is None or not getattr(
# attn_metadata, 'enable_kv_scales_calculation', False):
# return
self = forward_context.no_compile_layers[layer_name]
self.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake( query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales,
mutates_args=[],
fake_impl=maybe_calc_kv_scales_fake,
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,)
def unified_attention(
query: torch.Tensor,
......
......@@ -99,7 +99,8 @@ def flash_mla_with_kvcache(
if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm():
if kv_cache_dtype == "fp8":
if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
kv_dtype = "fp8_e4m3" if kv_cache_dtype == "fp8" else kv_cache_dtype
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q,
k_cache,
......@@ -112,7 +113,7 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata,
num_splits,
k_scale,
"fp8_e4m3",
kv_dtype,
)
return out, softmax_lse
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
......
......@@ -183,8 +183,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.uint8,
# "fp8_e4m3": torch.uint8,
# "fp8_e5m2": torch.uint8,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
"int8": torch.int8,
}
......
......@@ -150,9 +150,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
if self.kv_cache_dtype != "fp8":
raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
return
raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment