Unverified commit b1100846, authored by Baizhou Zhang, committed by GitHub

Refactor flashinfer logic for deepseek v3 and fix accuracy bug (#3785)

parent 27a46317
@@ -14,6 +14,7 @@
 import json
 import logging
+import math
 from enum import IntEnum, auto
 from typing import List, Optional, Set, Union
@@ -103,7 +104,20 @@ class ModelConfig:
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
             self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+            self.v_head_dim = self.hf_config.v_head_dim
+
+            # Handle rope scaling with yarn
+            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
+
+            if self.hf_config.rope_scaling:
+                mscale_all_dim = self.hf_config.rope_scaling.get(
+                    "mscale_all_dim", False
+                )
+                scaling_factor = self.hf_config.rope_scaling["factor"]
+                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+                self.scaling = self.scaling * mscale * mscale
         elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 128
             self.attention_arch = AttentionArch.MLA
@@ -414,3 +428,9 @@ def is_multimodal_model(model_architectures: List[str]):
 def is_encoder_decoder_model(model_architectures: List[str]):
     return "MllamaForConditionalGeneration" in model_architectures
+
+
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
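As a quick sanity check of the scaling change above, the snippet below recomputes the effective softmax scale for hypothetical yarn parameters. The head dims and rope_scaling values here are illustrative stand-ins, not read from a real config; the real values come from the model's hf_config.

import math

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

# Hypothetical MLA head dims and yarn settings, for illustration only.
qk_nope_head_dim, qk_rope_head_dim = 128, 64
scaling_factor, mscale_all_dim = 40.0, 1.0

scaling = 1 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)  # base 1/sqrt(d)
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
print(scaling * mscale * mscale)  # effective sm_scale handed to the attention kernels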
from __future__ import annotations
"""
Support attention backend for flashinfer MLA.
When the radix cache is enabled, the backend uses only the BatchMLAPaged wrapper for forwarding.
When the radix cache is disabled, the backend uses the BatchPrefill wrappers for prefilling (with or without a prefix cache)
and the BatchMLAPaged wrapper for decoding.
More details can be found in https://docs.flashinfer.ai/api/mla.html
"""
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
import torch
from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
should_use_tensor_core,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
if is_flashinfer_available():
from flashinfer import (
BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.mla import BatchMLAPagedAttentionWrapper
@dataclass
class DecodeMetadata:
decode_wrapper: BatchMLAPagedAttentionWrapper
@dataclass
class PrefillMetadata:
prefill_wrapper: Union[
BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
]
use_ragged: bool
# Reuse this workspace buffer across all flashinfer wrappers
global_workspace_buffer = None
class FlashInferMLAAttnBackend(AttentionBackend):
"""Flashinfer attention kernels."""
def __init__(
self,
model_runner: ModelRunner,
kv_indptr_buf: Optional[torch.Tensor] = None,
):
super().__init__()
# Parse constants
self.max_context_len = model_runner.model_config.context_len
global_config.enable_flashinfer_mla = True
# Allocate buffers
global global_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None:
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
else:
self.kv_indptr = kv_indptr_buf
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD"
)
if not global_server_args_dict["disable_radix_cache"]:
# use mla paged prefill
self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
backend="auto",
)
else:
self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
backend="auto",
)
self.decode_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer, backend="auto"
)
# Create indices updater
self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
model_runner, self
)
self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
model_runner, self
)
# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
self.decode_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {}
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
decode_wrapper=self.decode_wrapper,
)
self.forward_metadata = DecodeMetadata(self.decode_wrapper)
else:
prefix_lens = forward_batch.extend_prefix_lens
use_ragged = global_server_args_dict["disable_radix_cache"]
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens,
prefill_wrapper_paged=self.prefill_wrapper_paged,
use_ragged=use_ragged,
)
self.forward_metadata = PrefillMetadata(
self.prefill_wrapper_paged, use_ragged
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
):
if kv_indices_buf is None:
cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len,),
dtype=torch.int32,
device="cuda",
)
else:
cuda_graph_kv_indices = kv_indices_buf
self.cuda_graph_kv_indices = cuda_graph_kv_indices
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device="cuda",
)
self.cuda_graph_qk_indptr = self.kv_indptr.clone()
self.cuda_graph_qo_indptr = self.kv_indptr.clone()
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
decode_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
use_cuda_graph=True,
qo_indptr=self.qo_indptr[: num_tokens + 1],
kv_indptr=self.kv_indptr[: num_tokens + 1],
kv_indices=self.cuda_graph_kv_indices,
kv_len_arr=self.kv_last_page_len[:num_tokens],
backend="auto",
)
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_decode.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
decode_wrapper=decode_wrapper,
)
self.decode_cuda_graph_metadata[bs] = decode_wrapper
self.forward_metadata = DecodeMetadata(decode_wrapper)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
decode_wrapper=self.decode_cuda_graph_metadata[bs],
)
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")
def get_cuda_graph_seq_len_fill_value(self):
return 0
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
cache_loc = forward_batch.out_cache_loc
logits_soft_cap = layer.logit_cap
if not global_server_args_dict["disable_radix_cache"]:
# use mla paged prefill
prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
o = prefill_wrapper_paged.run(
qall[:, :, : layer.v_head_dim],
qall[:, :, layer.v_head_dim :],
k_buf[:, :, : layer.v_head_dim],
k_buf[:, :, layer.v_head_dim :],
)
else:
# use mla ragged prefill
o, _ = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
# FIXME: There should be another paged prefill call here.
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
decode_wrapper = self.forward_metadata.decode_wrapper
cache_loc = forward_batch.out_cache_loc
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
o = decode_wrapper.run(
reshaped_q[:, :, : layer.v_head_dim],
reshaped_q[:, :, layer.v_head_dim :],
reshaped_k[:, :, : layer.v_head_dim],
reshaped_k[:, :, layer.v_head_dim :],
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
class FlashInferMLAIndicesUpdaterDecode:
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
# Parse Constants
self.num_local_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.scaling = model_runner.model_config.scaling
self.data_type = model_runner.kv_cache_dtype
self.attn_backend = attn_backend
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.req_to_token = model_runner.req_to_token_pool.req_to_token
def update(
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrapper: BatchMLAPagedAttentionWrapper,
):
decode_wrapper = decode_wrapper or self.attn_backend.decode_wrapper
self.call_begin_forward(
decode_wrapper,
req_pool_indices,
seq_lens,
seq_lens_sum,
self.kv_indptr,
)
def call_begin_forward(
self,
wrapper: BatchMLAPagedAttentionWrapper,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
):
bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.shape[1],
)
sm_scale = self.scaling
q_indptr = torch.arange(0, bs + 1).to(0).int()
kv_lens = paged_kernel_lens.to(torch.int32)
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)
class FlashInferMLAIndicesUpdaterPrefill:
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
# Parse Constants
self.num_qo_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.v_head_dim = model_runner.model_config.v_head_dim
self.scaling = model_runner.model_config.scaling
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.attn_backend = attn_backend
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.qo_indptr = attn_backend.qo_indptr
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
def update(
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
prefill_wrapper_paged: Union[
BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
],
use_ragged: bool,
):
if use_ragged:
paged_kernel_lens = prefix_lens
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
else:
paged_kernel_lens = seq_lens
paged_kernel_lens_sum = seq_lens_sum
self.call_begin_forward(
self.prefill_wrapper_ragged,
prefill_wrapper_paged,
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
seq_lens,
prefix_lens,
self.kv_indptr,
self.qo_indptr,
use_ragged,
)
def call_begin_forward(
self,
wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
wrapper_paged: Union[
BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
],
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
seq_lens: torch.Tensor,
prefix_lens: torch.Tensor,
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
):
bs = len(req_pool_indices)
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum,
dtype=torch.int32,
device=req_pool_indices.device,
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.shape[1],
)
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
sm_scale = self.scaling
# extend part
if use_ragged:
wrapper_ragged.begin_forward(
qo_indptr=qo_indptr,
kv_indptr=qo_indptr,
num_qo_heads=self.num_qo_heads,
num_kv_heads=self.num_kv_heads,
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
head_dim_vo=self.v_head_dim,
q_data_type=self.q_data_type,
)
if not global_server_args_dict["disable_radix_cache"]:
# mla paged prefill
kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
wrapper_paged.plan(
qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
self.num_qo_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
True,
sm_scale,
self.q_data_type,
self.data_type,
)
# FIXME: Should there be paged prefill plan() logic here when the radix cache is disabled?
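For readers puzzling over the q/k slicing in forward_extend and forward_decode, here is a standalone shape sketch. It assumes DeepSeek-style MLA sizes (kv_lora_rank 512, rope dim 64); in the absorbed path, layer.head_dim is kv_lora_rank + qk_rope_head_dim and layer.v_head_dim equals kv_lora_rank, so slicing at v_head_dim splits the compressed-KV ("nope") part from the rope ("pe") part:

import torch

kv_lora_rank, qk_rope_head_dim = 512, 64      # illustrative MLA dimensions
head_dim = kv_lora_rank + qk_rope_head_dim    # 576, what layer.head_dim holds here
num_heads, num_tokens = 16, 4

q = torch.randn(num_tokens, num_heads, head_dim)
q_nope, q_pe = q[:, :, :kv_lora_rank], q[:, :, kv_lora_rank:]
print(q_nope.shape, q_pe.shape)               # torch.Size([4, 16, 512]) torch.Size([4, 16, 64])

k_buf = torch.randn(1024, 1, head_dim)        # stand-in for the paged KV buffer
ckv, kpe = k_buf[:, :, :kv_lora_rank], k_buf[:, :, kv_lora_rank:]
print(ckv.shape, kpe.shape)                   # torch.Size([1024, 1, 512]) torch.Size([1024, 1, 64])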
@@ -34,6 +34,7 @@ from sglang.srt.distributed import (
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
@@ -113,9 +114,9 @@ class ModelRunner:
         if self.server_args.device != "cpu":
             if server_args.enable_flashinfer_mla:
                 logger.info(
-                    "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                    "MLA optimization is turned on. Use flashinfer mla backend."
                 )
-                self.server_args.attention_backend = "flashinfer"
+                self.server_args.attention_backend = "flashinfer_mla"
             else:
                 logger.info("MLA optimization is turned on. Use triton backend.")
                 self.server_args.attention_backend = "triton"
@@ -703,6 +704,8 @@ class ModelRunner:
             self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
+        elif self.server_args.attention_backend == "flashinfer_mla":
+            self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
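A condensed, editor-provided view of the routing change above (the function name is invented; the real logic is spread across ModelRunner):

def resolve_attention_backend(enable_flashinfer_mla: bool, requested: str) -> str:
    # With FlashInfer MLA enabled, the runner now forces the dedicated
    # "flashinfer_mla" backend instead of the generic "flashinfer" one.
    if enable_flashinfer_mla:
        return "flashinfer_mla"
    return requested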
@@ -510,25 +510,27 @@ class DeepseekV2AttentionMLA(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        if global_server_args_dict["enable_flashinfer_mla"]:
-            if global_server_args_dict["disable_radix_cache"]:
-                if forward_batch.forward_mode.is_extend():
-                    return self.forward_normal(positions, hidden_states, forward_batch)
-                else:
-                    return self.forward_absorb(positions, hidden_states, forward_batch)
-            else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
-        else:
-            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
-            ):
-                return self.forward_normal(positions, hidden_states, forward_batch)
-            else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
+        def no_absorb() -> bool:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                # Flashinfer MLA: only skip weight absorption when prefilling/extending without the radix cache
+                return (
+                    global_server_args_dict["disable_radix_cache"]
+                    and forward_batch.forward_mode.is_extend()
+                )
+            else:
+                # Triton: use normal computation for prefill and weight absorption for extend/decode
+                return (
+                    forward_batch.forward_mode.is_extend()
+                    and not forward_batch.forward_mode.is_target_verify()
+                    and not forward_batch.forward_mode.is_draft_extend()
+                    and forward_batch.extend_prefix_lens.sum() == 0
+                )
+
+        if no_absorb():
+            return self.forward_normal(positions, hidden_states, forward_batch)
+        else:
+            return self.forward_absorb(positions, hidden_states, forward_batch)

     def forward_normal(
         self,
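To make the new dispatch concrete, here is an editor's toy restatement of no_absorb() with a few checks. It is simplified: the speculative-decoding modes on the triton path are omitted and prefix handling is reduced to a single flag.

def no_absorb(enable_flashinfer_mla: bool, disable_radix_cache: bool,
              is_extend: bool, has_prefix: bool) -> bool:
    if enable_flashinfer_mla:
        # FlashInfer MLA: skip absorption only for ragged prefill (radix cache off).
        return disable_radix_cache and is_extend
    # Triton: skip absorption only for pure prefill with no cached prefix
    # (target-verify / draft-extend modes are ignored in this toy version).
    return is_extend and not has_prefix

assert no_absorb(True, True, True, False)        # ragged prefill -> forward_normal
assert not no_absorb(True, False, True, True)    # MLA paged prefill -> forward_absorb
assert not no_absorb(True, True, False, False)   # decode -> forward_absorb
assert no_absorb(False, False, True, False)      # triton pure prefill -> forward_normal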