Commit 50f7ea0f authored by linhai1

support fp8_e4m3.

parents 484c5433 6741925c
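
This change makes the DCU MLA fp8 decode path derive the `kv_cache_dtype` flag from the KV cache tensor's dtype instead of hard-coding `"fp8_e4m3"`, and adds handling for the fnuz fp8 variants. A minimal sketch of the added mapping (illustrative only; the helper name `select_fp8_kv_cache_dtype` is not part of this commit):

```python
import torch

def select_fp8_kv_cache_dtype(k_cache: torch.Tensor) -> str:
    # Map the fnuz fp8 storage dtypes (used on DCU/ROCm) to the string flag
    # passed to the fp8 decode kernel; only the fnuz variants are mapped,
    # mirroring the hunks below.
    if k_cache.dtype == torch.float8_e4m3fnuz:
        return "fp8_e4m3"
    if k_cache.dtype == torch.float8_e5m2fnuz:
        return "fp8_e5m2"
    raise ValueError(f"unsupported fp8 KV cache dtype: {k_cache.dtype}")
```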
@@ -47,6 +47,16 @@ class VllmMLADecodeMetadata:
     num_splits: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
 
+    def __init__(
+        self,
+        flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        num_splits: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.flashmla_metadata = flashmla_metadata
+        self.num_splits = num_splits
+        self.block_kv_indices = block_kv_indices
+
 
 class DCUMLABackend(AttentionBackend):
     def __init__(
@@ -86,14 +96,6 @@ class DCUMLABackend(AttentionBackend):
         self.skip_prefill = skip_prefill
 
         if not skip_prefill:
-            # Use the triton backend for now; consider replacing it later
-            # from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-            # self.triton_backend = TritonAttnBackend(
-            #     model_runner,
-            #     skip_prefill=False,
-            #     kv_indptr_buf=kv_indptr_buf,
-            # )
-            # Switch prefill to flash attn
             from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
             self.flashattn_backend = FlashAttentionBackend(
                 model_runner,
@@ -147,9 +149,7 @@ class DCUMLABackend(AttentionBackend):
                 mla_metadata, num_splits_t, block_kv_indices
             )
         else:
-            # prefill/extend used the triton backend; switched to flash attn
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata(forward_batch)
                 self.flashattn_backend.init_forward_metadata(forward_batch)
 
     def init_cuda_graph_state(
@@ -241,15 +241,6 @@ class DCUMLABackend(AttentionBackend):
             )
         else:
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata_capture_cuda_graph(
-                #     bs,
-                #     num_tokens,
-                #     req_pool_indices,
-                #     seq_lens,
-                #     encoder_lens,
-                #     forward_mode,
-                #     spec_info,
-                # )
                 self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
                     bs,
                     num_tokens,
@@ -321,16 +312,6 @@ class DCUMLABackend(AttentionBackend):
             ]
         else:
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata_replay_cuda_graph(
-                #     bs,
-                #     req_pool_indices,
-                #     seq_lens,
-                #     seq_lens_sum,
-                #     encoder_lens,
-                #     forward_mode,
-                #     spec_info,
-                #     seq_lens_cpu,
-                # )
                 self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
                     bs,
                     req_pool_indices,
@@ -413,6 +394,10 @@ class DCUMLABackend(AttentionBackend):
                 getattr(torch, "float8_e5m2", None),
                 getattr(torch, "float8_e5m2fnuz", None),
             ):
+                if k_cache_reshaped.dtype == torch.float8_e4m3fnuz:
+                    kv_cache_dtype = "fp8_e4m3"
+                elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz:
+                    kv_cache_dtype = "fp8_e5m2"
                 k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
                 o = self._call_fp8_decode(
                     reshape_q,
@@ -421,7 +406,7 @@ class DCUMLABackend(AttentionBackend):
                     forward_batch.seq_lens.to(torch.int32),
                     layer.scaling,
                     k_scale.to(torch.float32),
-                    kv_cache_dtype="fp8_e4m3",
+                    kv_cache_dtype=kv_cache_dtype,
                 )
             else:
                 o = self._call_decode(
@@ -442,7 +427,6 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
@@ -455,11 +439,7 @@ class DCUMLABackend(AttentionBackend):
             forward_batch.forward_mode == ForwardMode.EXTEND
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
         ):
-            # flash_attn does not support fp8, so extend cannot run correctly with fp8
             if not self.skip_prefill:
-                # return self.triton_backend.forward_extend(
-                #     q, k, v, layer, forward_batch, save_kv_cache, sinks
-                # )
                 return self.flashattn_backend.forward_extend(
                     q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
                 )
@@ -484,6 +464,10 @@ class DCUMLABackend(AttentionBackend):
                 getattr(torch, "float8_e5m2", None),
                 getattr(torch, "float8_e5m2fnuz", None),
             ):
+                if k_cache_reshaped.dtype == torch.float8_e4m3fnuz:
+                    k_cache_reshaped = k_cache_reshaped.view(torch.float8_e4m3fn)
+                elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz:
+                    k_cache_reshaped = k_cache_reshaped.view(torch.float8_e5m2)
                 k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
                 o = self._call_fp8_decode(
                     reshape_q,
...
@@ -858,11 +858,6 @@ class FlashAttentionBackend(AttentionBackend):
                 **kwargs,
             )
         else:
-            # MHA for extend part of sequence without attending prefix kv cache
-            # if layer.layer_id == 0:
-            #     print("q.dtype, k.shape, v.shape, k.dtype, v.dtype, layer.k_scale.shape, layer.k_scale.dtype, layer.v_scale.shape, layer.v_scale.dtype, \n",
-            #           q.dtype, k.shape, v.shape, k.dtype, v.dtype, )
-            #     print("layer info: \n", layer)
             output = flash_attn_varlen_func(
                 q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
                 k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
...
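
For reference, the extend-path decode branch handles the fp8 cache roughly as sketched below: reinterpret a fnuz-typed cache as the corresponding standard fp8 dtype and fall back to a unit `k_scale` when the layer provides none. This is an illustrative condensation of the added hunks, not code from the commit; the wrapper name `prepare_fp8_kv_cache` is hypothetical.

```python
import torch

def prepare_fp8_kv_cache(k_cache_reshaped: torch.Tensor, k_scale, device):
    # Reinterpret the DCU/ROCm *fnuz fp8 cache as the standard fp8 dtypes
    # before handing it to the fp8 decode kernel (bit reinterpretation only,
    # element size is unchanged).
    if k_cache_reshaped.dtype == torch.float8_e4m3fnuz:
        k_cache_reshaped = k_cache_reshaped.view(torch.float8_e4m3fn)
    elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz:
        k_cache_reshaped = k_cache_reshaped.view(torch.float8_e5m2)

    # Match the diff's fallback: use a unit scale if the layer has no k_scale.
    if k_scale is None:
        k_scale = torch.tensor([1.0], dtype=torch.float32, device=device)
    return k_cache_reshaped, k_scale.to(torch.float32)
```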