Commit 484c5433 authored by linhai1

support fp8_e4m3.
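For context, the fp8_e4m3 path added below threads a per-tensor k_scale and a kv_cache_dtype string into the quantized decode kernel (flash_mla_with_kvcache_quantization) in place of the old is_fp8_kvcache flag; the kernel dequantizes the cache internally using that scale. A minimal sketch of the numerics being assumed, using torch.float8_e4m3fn and a per-tensor scale (the helper names are illustrative only, not the kernel API):

import torch

# Illustrative per-tensor fp8_e4m3 quantization of a KV tensor.
# The decode kernel receives the fp8 cache plus k_scale and dequantizes internally;
# this only shows the round trip implied by that contract.
def quantize_kv_fp8_e4m3(k: torch.Tensor):
    # Scale so the max magnitude maps near the e4m3 max representable value (~448).
    k_scale = (k.abs().max().float() / 448.0).clamp(min=1e-12)
    k_fp8 = (k.float() / k_scale).to(torch.float8_e4m3fn)
    return k_fp8, k_scale

def dequantize_kv_fp8_e4m3(k_fp8: torch.Tensor, k_scale: torch.Tensor, dtype=torch.bfloat16):
    # Multiply back by the stored scale to approximately recover k.
    return (k_fp8.float() * k_scale).to(dtype)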

parent 93eb92f8
@@ -2,7 +2,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
import torch
import triton
@@ -363,7 +363,7 @@ class DCUMLABackend(AttentionBackend):
def _call_fp8_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
block_table: torch.Tensor, cache_seqlens: torch.Tensor,
scaling: float):
scaling: float, k_scale=None, kv_cache_dtype=None):
assert _has_flash_mla, "FP8 KV cache requires the flash_mla package"
o, _ = flash_mla_with_kvcache_quantization(
q=reshape_q,
@@ -375,7 +375,8 @@ class DCUMLABackend(AttentionBackend):
num_splits=self.forward_metadata.num_splits,
softmax_scale=scaling,
causal=True,
is_fp8_kvcache=True,
k_scale=k_scale,
kv_cache_dtype=kv_cache_dtype,
)
return o
@@ -412,14 +413,23 @@ class DCUMLABackend(AttentionBackend):
getattr(torch, "float8_e5m2", None),
getattr(torch, "float8_e5m2fnuz", None),
):
k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
o = self._call_fp8_decode(
reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
forward_batch.seq_lens.to(torch.int32), layer.scaling,
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
forward_batch.seq_lens.to(torch.int32),
layer.scaling,
k_scale.to(torch.float32),
kv_cache_dtype="fp8_e4m3",
)
else:
o = self._call_decode(
reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
forward_batch.seq_lens.to(torch.int32), layer.scaling,
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
forward_batch.seq_lens.to(torch.int32),
layer.scaling,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -474,14 +484,21 @@ class DCUMLABackend(AttentionBackend):
getattr(torch, "float8_e5m2", None),
getattr(torch, "float8_e5m2fnuz", None),
):
k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
o = self._call_fp8_decode(
reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
layer.scaling,
k_scale.to(torch.float32),
kv_cache_dtype=self.data_type,
)
else:
o = self._call_decode(
reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
layer.scaling,
)
@@ -489,3 +506,95 @@ class DCUMLABackend(AttentionBackend):
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
class DCUMLAMultiStepDraftBackend:
"""
Wrap multiple DCU MLA attention backends as one for multiple consecutive
draft decoding steps.
"""
def __init__(
self,
model_runner: ModelRunner,
topk: int,
speculative_num_steps: int,
):
if topk > 1:
raise ValueError(
"Currently FlashMLA only supports topk=1 for speculative decoding"
)
self.topk = topk
self.speculative_num_steps = speculative_num_steps
max_bs = model_runner.req_to_token_pool.size * self.topk
self.kv_indptr = torch.zeros(
(
self.speculative_num_steps,
max_bs + 1,
),
dtype=torch.int32,
device=model_runner.device,
)
self.attn_backends = []
for i in range(self.speculative_num_steps - 1):
self.attn_backends.append(
DCUMLABackend(
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
kv_last_page_len_buf=None,
)
)
def common_template(
self,
forward_batch: ForwardBatch,
call_fn: Callable,
):
assert forward_batch.spec_info is not None
for i in range(self.speculative_num_steps - 1):
call_fn(i, forward_batch)
def init_forward_metadata(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
assert forward_batch.spec_info is not None
self.attn_backends[i].init_forward_metadata(forward_batch)
self.common_template(forward_batch, call_fn)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
for i in range(self.speculative_num_steps - 1):
self.attn_backends[i].init_cuda_graph_state(
max_bs, max_num_tokens, block_kv_indices=None
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
self.common_template(forward_batch, call_fn)
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=forward_batch.seq_lens_cpu,
)
self.common_template(forward_batch, call_fn)
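A rough usage sketch for the wrapper class above; model_runner and the per-step forward_batch are assumed to come from the existing speculative-decoding worker, and the step count here is arbitrary:

# Hypothetical driver, assuming topk=1 (the only supported value) and four draft steps.
draft_backend = DCUMLAMultiStepDraftBackend(
    model_runner, topk=1, speculative_num_steps=4
)
# Only needed when draft decode is captured with CUDA graphs.
draft_backend.init_cuda_graph_state(max_bs=64, max_num_tokens=64)
# Before each round of draft decoding, refresh metadata for every wrapped
# per-step DCUMLABackend from the current forward batch.
draft_backend.init_forward_metadata(forward_batch)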
@@ -695,6 +695,7 @@ class FlashAttentionBackend(AttentionBackend):
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since the fa3 kernel requires fp16 or bf16 data types in this case,
# 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
data_dtype = q.dtype
if (
self.kv_cache_dtype_str != "auto"
and layer.head_dim <= 256
@@ -829,6 +830,8 @@ class FlashAttentionBackend(AttentionBackend):
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
k_descale = k_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
v_descale = v_descale if layer.v_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
# Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache:
assert not get_global_server_args().disable_chunked_prefix_cache
@@ -842,9 +845,9 @@ class FlashAttentionBackend(AttentionBackend):
assert forward_batch.mha_return_lse
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
max_seqlen_q=metadata.max_seq_len_q,
@@ -856,10 +859,14 @@ class FlashAttentionBackend(AttentionBackend):
)
else:
# MHA for extend part of sequence without attending prefix kv cache
# if layer.layer_id == 0:
# print("q.dtype, k.shape, v.shape, k.dtype, v.dtype, layer.k_scale.shape, layer.k_scale.dtype, layer.v_scale.shape, layer.v_scale.dtype, \n",
# q.dtype, k.shape, v.shape, k.dtype, v.dtype, )
# print("layer info: \n", layer)
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
max_seqlen_q=metadata.max_seq_len_q,
@@ -2296,7 +2296,7 @@ class DeepseekV2AttentionMLA(nn.Module):
# Fetch latent cache from memory pool with precomputed chunked kv indices
latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
self.attn_mha.layer_id
)
).to(q.dtype)
latent_cache = (
latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
.contiguous()
@@ -46,6 +46,7 @@ class DraftBackendFactory:
else self._create_triton_decode_backend
),
"flashmla": self._create_flashmla_decode_backend,
"dcu_mla": self._create_dcumla_decode_backend,
"trtllm_mha": self._create_trtllm_mha_decode_backend,
"trtllm_mla": self._create_trtllm_mla_decode_backend,
"nsa": self._create_nsa_decode_backend,
@@ -69,6 +70,7 @@ class DraftBackendFactory:
else self._create_triton_prefill_backend
),
"flashmla": self._create_flashmla_prefill_backend,
"dcu_mla": self._create_dcumla_prefill_backend,
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
"nsa": self._create_nsa_prefill_backend,
@@ -150,6 +152,15 @@ class DraftBackendFactory:
self.draft_model_runner, self.topk, self.speculative_num_steps
)
def _create_dcumla_decode_backend(self):
from sglang.srt.layers.attention.dcu_mla_backend import (
DCUMLAMultiStepDraftBackend,
)
return DCUMLAMultiStepDraftBackend(
self.draft_model_runner, self.topk, self.speculative_num_steps
)
def _create_trtllm_mha_decode_backend(self):
from sglang.srt.layers.attention.trtllm_mha_backend import (
TRTLLMHAAttnMultiStepDraftBackend,
@@ -224,3 +235,9 @@ class DraftBackendFactory:
"flashmla prefill backend is not yet supported for draft extend."
)
return None
def _create_dcumla_prefill_backend(self):
logger.warning(
"flashmla prefill backend is not yet supported for draft extend."
)
return None