Unverified Commit efbc687c authored by fzyzcjy, committed by GitHub
parent 292a867a
@@ -293,6 +293,7 @@ class ForwardBatch:
    # For padding
    padded_static_len: int = -1  # -1 if not padded
    num_token_non_padded: Optional[torch.Tensor] = None  # scalar tensor
+    num_token_non_padded_cpu: int = None

    # For Qwen2-VL
    mrope_positions: torch.Tensor = None
@@ -354,6 +355,7 @@ class ForwardBatch:
        ret.num_token_non_padded = torch.tensor(
            len(batch.input_ids), dtype=torch.int32
        ).to(device, non_blocking=True)
+        ret.num_token_non_padded_cpu = len(batch.input_ids)

        # For MLP sync
        if batch.global_num_tokens is not None:
...
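
Aside (illustration, not part of the commit): the new num_token_non_padded_cpu field mirrors the GPU scalar as a plain Python int, so host-side code can branch on the token count without a blocking .item() sync. A minimal runnable sketch of the pattern, using a stand-in dataclass rather than the real ForwardBatch:

# Sketch only: PaddedBatchInfo and from_input_ids are placeholders, not SGLang APIs.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class PaddedBatchInfo:
    num_token_non_padded: Optional[torch.Tensor] = None  # scalar tensor on device
    num_token_non_padded_cpu: Optional[int] = None       # plain Python int mirror

    @classmethod
    def from_input_ids(cls, input_ids, device):
        n = len(input_ids)
        return cls(
            # async H2D copy; safe because the CPU mirror keeps the value for host code
            num_token_non_padded=torch.tensor(n, dtype=torch.int32).to(
                device, non_blocking=True
            ),
            num_token_non_padded_cpu=n,
        )


if __name__ == "__main__":
    info = PaddedBatchInfo.from_input_ids(list(range(7)), device="cpu")
    assert info.num_token_non_padded_cpu == 7
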
@@ -31,7 +31,12 @@ import torch.distributed as dist
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.configs.model_config import (
+    AttentionArch,
+    ModelConfig,
+    get_nsa_index_head_dim,
+    is_deepseek_nsa,
+)
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.distributed import (
@@ -96,6 +101,7 @@ from sglang.srt.mem_cache.memory_pool import (
    HybridReqToTokenPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
+    NSATokenToKVPool,
    ReqToTokenPool,
    SWAKVPool,
)
@@ -157,6 +163,7 @@ MLA_ATTENTION_BACKENDS = [
    "cutlass_mla",
    "trtllm_mla",
    "ascend",
+    "nsa",
]
@@ -1547,6 +1554,7 @@ class ModelRunner:
            assert self.is_draft_worker

        # Initialize token_to_kv_pool
+        is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
        if self.server_args.attention_backend == "ascend":
            if self.use_mla_backend:
                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1555,6 +1563,7 @@ class ModelRunner:
                    dtype=self.kv_cache_dtype,
                    kv_lora_rank=self.model_config.kv_lora_rank,
                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                    index_head_dim=self.model_config.index_head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
@@ -1574,7 +1583,22 @@ class ModelRunner:
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
+        elif self.use_mla_backend and is_nsa_model:
+            self.token_to_kv_pool = NSATokenToKVPool(
+                self.max_total_num_tokens,
+                page_size=self.page_size,
+                dtype=self.kv_cache_dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.num_effective_layers,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
+                index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
+            )
        elif self.use_mla_backend:
+            assert not is_nsa_model
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
...
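
Aside (illustration, not part of the commit): the KV-pool selection above reduces to a small priority order. A simplified sketch with strings standing in for the real pool classes (the actual constructors in sglang.srt.mem_cache.memory_pool take many more arguments):

def choose_kv_pool(attention_backend: str, use_mla_backend: bool, is_nsa_model: bool) -> str:
    # Mirrors the elif ordering above: the ascend backend wins first, then the
    # NSA-specific pool, then plain MLA, otherwise the regular MHA pool.
    if attention_backend == "ascend":
        return "AscendMLAPagedTokenToKVPool"
    if use_mla_backend and is_nsa_model:
        return "NSATokenToKVPool"  # adds index_head_dim storage for the NSA indexer
    if use_mla_backend:
        assert not is_nsa_model  # NSA models must not fall through to plain MLA
        return "MLATokenToKVPool"
    return "MHATokenToKVPool"


assert choose_kv_pool("nsa", use_mla_backend=True, is_nsa_model=True) == "NSATokenToKVPool"
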
@@ -75,11 +75,16 @@ class NPUGraphRunner(CudaGraphRunner):
        self.positions[: self.raw_num_token].copy_(forward_batch.positions)

        # Replay
-        seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
-        thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
-        thread.start()
-        self.graphs[self.bs].replay()
-        thread.join()
+        if self.model_runner.model_config.index_head_dim is None:
+            seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
+                self.bs - self.raw_bs
+            )
+            thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
+            thread.start()
+            self.graphs[self.bs].replay()
+            thread.join()
+        else:
+            self.graphs[self.bs].replay()

        output = self.output_buffers[self.bs]
        if isinstance(output, LogitsProcessorOutput):
...
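
Aside (illustration, not part of the commit): on the non-NSA path the host-side input update is overlapped with graph replay via a background thread, while the NSA path (index_head_dim set) replays directly. A self-contained toy version of that pattern, with placeholder FakeGraph/update_inputs names:

import threading
import time


class FakeGraph:
    def replay(self):
        time.sleep(0.01)  # stands in for the captured kernel launches


def update_inputs(seq_lens):
    time.sleep(0.005)  # stands in for host-side bookkeeping on padded seq lens


def replay(graph: FakeGraph, seq_lens, needs_host_update: bool):
    if needs_host_update:
        # Overlap the host-side update with graph replay, then join before
        # touching the outputs.
        t = threading.Thread(target=update_inputs, args=(seq_lens,))
        t.start()
        graph.replay()
        t.join()
    else:
        # Sparse-attention path in the diff: replay directly, no seq_lens update.
        graph.replay()


replay(FakeGraph(), [4, 7, 0, 0], needs_host_update=True)
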
@@ -15,6 +15,7 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
+from __future__ import annotations

import concurrent.futures
import logging
@@ -25,10 +26,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
+from tqdm import tqdm
from transformers import PretrainedConfig

from sglang.srt import single_batch_overlap
+from sglang.srt.configs.model_config import (
+    get_nsa_index_head_dim,
+    get_nsa_index_n_heads,
+    get_nsa_index_topk,
+    is_deepseek_nsa,
+)
+from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.distributed import (
    get_moe_expert_parallel_world_size,
    get_pp_group,
@@ -48,6 +55,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
    NPUFusedMLAPreprocess,
    is_mla_preprocess_enabled,
)
+from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
from sglang.srt.layers.communicator import (
    LayerCommunicator,
    LayerScatterModes,
@@ -172,10 +180,13 @@ elif _is_hip:
    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_triton as awq_dequantize,
    )
+elif _is_npu:
+    import custom_ops
+    import sgl_kernel_npu
+    import torch_npu
else:
    pass

_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -184,6 +195,7 @@ logger = logging.getLogger(__name__)

FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
    "fa3",
+    "nsa",
    "flashinfer",
    "cutlass_mla",
    "trtllm_mla",
@@ -204,6 +216,9 @@ class AttnForwardMethod(IntEnum):
    # Use absorbed multi-latent attention
    MLA = auto()

+    # Use Deepseek V3.2 sparse multi-latent attention
+    NPU_MLA_SPARSE = auto()
+
    # Use multi-head attention, but with KV cache chunked.
    # This method can avoid OOM when prefix lengths are long.
    MHA_CHUNKED_KV = auto()
@@ -246,7 +261,13 @@ def handle_attention_ascend(attn, forward_batch):
        and not forward_batch.forward_mode.is_target_verify()
        and not forward_batch.forward_mode.is_draft_extend()
    ):
-        return AttnForwardMethod.MHA
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MHA
    else:
-        return AttnForwardMethod.MLA
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MLA
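
Aside (illustration, not part of the commit): the ascend handler now routes any attention layer that owns an indexer, i.e. a DeepSeek NSA layer, to the new sparse method. A condensed sketch with a stand-in enum and dummy layer classes; the real condition on forward modes is richer than the single flag used here:

from enum import IntEnum, auto


class AttnForwardMethod(IntEnum):
    MHA = auto()
    MLA = auto()
    NPU_MLA_SPARSE = auto()


def dispatch_ascend(attn, is_short_extend: bool) -> AttnForwardMethod:
    # Layers with an `indexer` attribute always take the sparse path.
    if hasattr(attn, "indexer"):
        return AttnForwardMethod.NPU_MLA_SPARSE
    return AttnForwardMethod.MHA if is_short_extend else AttnForwardMethod.MLA


class _DenseAttn: ...


class _NsaAttn:
    indexer = object()


assert dispatch_ascend(_DenseAttn(), is_short_extend=True) is AttnForwardMethod.MHA
assert dispatch_ascend(_NsaAttn(), is_short_extend=False) is AttnForwardMethod.NPU_MLA_SPARSE
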
@@ -267,7 +288,9 @@ def _is_extend_without_speculative(forward_batch):
    )


-def _handle_attention_backend(attn, forward_batch, backend_name):
+def _handle_attention_backend(
+    attn: DeepseekV2AttentionMLA, forward_batch, backend_name
+):
    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
    disable_ragged = (
        backend_name in ["flashinfer", "flashmla"]
@@ -333,6 +356,10 @@ def handle_attention_aiter(attn, forward_batch):
        return AttnForwardMethod.MLA


+def handle_attention_nsa(attn, forward_batch):
+    return AttnForwardMethod.MLA
+
+
def handle_attention_triton(attn, forward_batch):
    if (
        _is_extend_without_speculative(forward_batch)
@@ -1005,6 +1032,10 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

+        # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
        # For tensor parallel attention
        if self.q_lora_rank is not None:
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
@@ -1042,6 +1073,26 @@ class DeepseekV2AttentionMLA(nn.Module):
                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            )

+        self.use_nsa = is_deepseek_nsa(config)
+        if self.use_nsa:
+            self.indexer = Indexer(
+                hidden_size=hidden_size,
+                index_n_heads=get_nsa_index_n_heads(config),
+                index_head_dim=get_nsa_index_head_dim(config),
+                rope_head_dim=qk_rope_head_dim,
+                index_topk=get_nsa_index_topk(config),
+                q_lora_rank=q_lora_rank,
+                max_position_embeddings=max_position_embeddings,
+                rope_theta=rope_theta,
+                scale_fmt="ue8m0",
+                block_size=128,
+                rope_scaling=rope_scaling,
+                prefix=add_prefix("indexer", prefix),
+                quant_config=quant_config,
+                layer_id=layer_id,
+                alt_stream=alt_stream,
+            )
+
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1064,9 +1115,6 @@ class DeepseekV2AttentionMLA(nn.Module):
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        if rope_scaling:
-            rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
@@ -1193,8 +1241,8 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
        if self.is_mla_preprocess_enabled:
            assert (
-                quant_config.get_name() == "w8a8_int8"
-            ), "MLA Preprocess only works with W8A8Int8"
+                quant_config is None or quant_config.get_name() == "w8a8_int8"
+            ), "MLA Preprocess only works with Unquant or W8A8Int8"
            self.mla_preprocess = None

    def dispatch_attn_forward_method(
@@ -1272,7 +1320,6 @@ class DeepseekV2AttentionMLA(nn.Module):
            return hidden_states, None, forward_batch, None

        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
        if attn_forward_method == AttnForwardMethod.MHA:
            inner_state = self.forward_normal_prepare(
                positions, hidden_states, forward_batch, zero_allocator
@@ -1304,6 +1351,10 @@ class DeepseekV2AttentionMLA(nn.Module):
            inner_state = self.mla_preprocess.forward(
                positions, hidden_states, forward_batch, zero_allocator
            )
+        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+            inner_state = self.forward_npu_sparse_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            inner_state = self.forward_absorb_fused_mla_rope_prepare(
                positions, hidden_states, forward_batch, zero_allocator
@@ -1329,6 +1380,8 @@ class DeepseekV2AttentionMLA(nn.Module):
            return self.forward_normal_chunked_kv_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA:
            return self.forward_absorb_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+            return self.forward_npu_sparse_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            return self.forward_absorb_fused_mla_rope_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
@@ -1424,6 +1477,7 @@ class DeepseekV2AttentionMLA(nn.Module):
    ):
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

+        q_lora = None
        if self.q_lora_rank is not None:
            if (
                (not isinstance(hidden_states, tuple))
@@ -1462,6 +1516,10 @@ class DeepseekV2AttentionMLA(nn.Module):
                    q = self.q_a_layernorm(q)
                    k_nope = self.kv_a_layernorm(k_nope)

+            # q_lora needed by indexer
+            if self.use_nsa:
+                q_lora = q
+
            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
@@ -1527,14 +1585,41 @@ class DeepseekV2AttentionMLA(nn.Module):
            q_nope_out = q_nope_out.transpose(0, 1)

        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
-            not _use_aiter or not _is_gfx95_supported
+            not _use_aiter or not _is_gfx95_supported or self.use_nsa
        ):
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

-        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+        topk_indices = None
+        if q_lora is not None:
+            topk_indices = self.indexer(
+                x=hidden_states,
+                q_lora=q_lora,
+                positions=positions,
+                forward_batch=forward_batch,
+                layer_id=self.layer_id,
+            )
+
+        return (
+            q_pe,
+            k_pe,
+            q_nope_out,
+            k_nope,
+            forward_batch,
+            zero_allocator,
+            positions,
+            topk_indices,
+        )
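
Aside (illustration, not part of the commit): topk_indices returned above is, conceptually, one list of key positions per query token that the sparse attention kernel will visit. The real Indexer scores candidates with its own projections, rope, and quantized math; the toy below only illustrates the shape contract with a plain dot-product top-k:

import torch


def toy_topk_indices(q: torch.Tensor, k: torch.Tensor, topk: int) -> torch.Tensor:
    """q: [num_q, dim], k: [num_k, dim] -> [num_q, topk] int64 indices."""
    scores = q @ k.T  # [num_q, num_k] similarity scores
    topk = min(topk, k.shape[0])  # never ask for more keys than exist
    return scores.topk(topk, dim=-1).indices


q = torch.randn(4, 64)
k = torch.randn(128, 64)
idx = toy_topk_indices(q, k, topk=16)
print(idx.shape)  # torch.Size([4, 16])
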
    def forward_absorb_core(
-        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+        self,
+        q_pe,
+        k_pe,
+        q_nope_out,
+        k_nope,
+        forward_batch,
+        zero_allocator,
+        positions,
+        topk_indices,
    ):
        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
            extra_args = {}
@@ -1543,6 +1628,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
                    "is_neox": self.rotary_emb.is_neox_style,
                }
            attn_output = self.attn_mqa(
                q_nope_out,
                k_nope,
@@ -1551,6 +1637,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                q_rope=q_pe,
                k_rope=k_pe,
                **extra_args,
+                **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
            )
        else:
            if _use_aiter_gfx95:
@@ -1570,7 +1657,13 @@ class DeepseekV2AttentionMLA(nn.Module):
            q = torch.cat([q_nope_out, q_pe], dim=-1)
            k = torch.cat([k_nope, k_pe], dim=-1)
-            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
+            attn_output = self.attn_mqa(
+                q,
+                k,
+                k_nope,
+                forward_batch,
+                **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
+            )
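
Aside (illustration, not part of the commit): both attn_mqa call sites use the `**(dict(...) if cond else {})` idiom so that `topk_indices` is only passed when it is set, and dense backends that do not accept the keyword keep working unchanged. In isolation, with stand-in functions:

def attend(q, k, v, batch, **kwargs):
    # a stand-in backend; real backends may or may not accept `topk_indices`
    topk_indices = kwargs.get("topk_indices")
    return "dense" if topk_indices is None else f"sparse over {len(topk_indices)} ids"


def call_backend(q, k, v, batch, topk_indices=None):
    return attend(
        q, k, v, batch,
        **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
    )


print(call_backend(None, None, None, None))                       # dense
print(call_backend(None, None, None, None, topk_indices=[1, 5]))  # sparse over 2 ids
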
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
        if self.use_deep_gemm_bmm:
@@ -1652,6 +1745,221 @@ class DeepseekV2AttentionMLA(nn.Module):
        return output

+    def forward_npu_sparse_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        """
+        Reuse `self.q_lora_rank is not None` branch from forward_absorb_prepare
+        """
+        if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
+            if self.mla_preprocess is None:
+                self.mla_preprocess = NPUFusedMLAPreprocess(
+                    self.fused_qkv_a_proj_with_mqa,
+                    self.q_a_layernorm,
+                    self.kv_a_layernorm,
+                    self.q_b_proj,
+                    self.w_kc,
+                    self.rotary_emb,
+                    self.layer_id,
+                    self.num_local_heads,
+                    self.qk_nope_head_dim,
+                    self.qk_rope_head_dim,
+                )
+            (
+                q_pe,
+                k_pe,
+                q_nope_out,
+                k_nope,
+                forward_batch,
+                zero_allocator,
+                positions,
+            ) = self.mla_preprocess.forward(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+            fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, _ = fused_qkv_a_proj_out.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            q_lora = self.q_a_layernorm(q)
+        else:
+            from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and get_is_capture_mode():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)
+
+            q_lora = q.clone()  # required for topk_indices
+            k_nope = k_nope.unsqueeze(1)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+
+            q_nope, q_pe = q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+            k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
+
+            if self.use_deep_gemm_bmm:
+                q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+                    per_token_group_quant_mla_deep_gemm_masked_fp8(
+                        q_nope.transpose(0, 1)
+                    )
+                )
+                q_nope_out = q_nope.new_empty(
+                    (self.num_local_heads, aligned_m, self.kv_lora_rank)
+                )
+                deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+                    (q_nope_val, q_nope_scale),
+                    (self.w_kc, self.w_scale_k),
+                    q_nope_out,
+                    masked_m,
+                    expected_m,
+                )
+                q_nope_out = q_nope_out[:, :expected_m, :]
+            elif _is_hip:
+                # TODO(haishaw): add bmm_fp8 to ROCm
+                if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                    x = q_nope.transpose(0, 1)
+                    q_nope_out = torch.empty(
+                        x.shape[0],
+                        x.shape[1],
+                        self.w_kc.shape[2],
+                        device=x.device,
+                        dtype=torch.bfloat16,
+                    )
+                    batched_gemm_afp4wfp4_pre_quant(
+                        x,
+                        self.w_kc.transpose(-2, -1),
+                        self.w_scale_k.transpose(-2, -1),
+                        torch.bfloat16,
+                        q_nope_out,
+                    )
+                else:
+                    q_nope_out = torch.bmm(
+                        q_nope.to(torch.bfloat16).transpose(0, 1),
+                        self.w_kc.to(torch.bfloat16) * self.w_scale,
+                    )
+            elif self.w_kc.dtype == torch.float8_e4m3fn:
+                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                    q_nope.transpose(0, 1),
+                    zero_allocator.allocate(1),
+                )
+                q_nope_out = bmm_fp8(
+                    q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+                )
+            else:
+                q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+
+            q_nope_out = q_nope_out.transpose(0, 1)
+
+            if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+                not _use_aiter or not _is_gfx95_supported
+            ):
+                q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+        # TODO: multi-stream indexer
+        topk_indices = self.indexer(
+            hidden_states, q_lora, positions, forward_batch, self.layer_id
+        )
+
+        return (
+            q_pe,
+            k_pe,
+            q_nope_out,
+            k_nope,
+            topk_indices,
+            forward_batch,
+            zero_allocator,
+            positions,
+        )
+
+    def forward_npu_sparse_core(
+        self,
+        q_pe,
+        k_pe,
+        q_nope_out,
+        k_nope,
+        topk_indices,
+        forward_batch,
+        zero_allocator,
+        positions,
+    ):
+        attn_output = self.attn_mqa(
+            q_nope_out.contiguous(),
+            k_nope.contiguous(),
+            k_nope.contiguous(),
+            forward_batch,
+            save_kv_cache=True,  # False if forward_batch.forward_mode.is_extend() else True,
+            q_rope=q_pe.contiguous(),
+            k_rope=k_pe.contiguous(),
+            topk_indices=topk_indices,
+        )
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        attn_bmm_output = torch.empty(
+            (attn_output.shape[0], self.num_local_heads, self.v_head_dim),
+            dtype=attn_output.dtype,
+            device=attn_output.device,
+        )
+        if not forward_batch.forward_mode.is_decode():
+            attn_output = attn_output.transpose(0, 1)
+            torch.bmm(
+                attn_output,
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+        else:
+            attn_output = attn_output.contiguous()
+            torch.ops.npu.batch_matmul_transpose(
+                attn_output, self.w_vc, attn_bmm_output
+            )
+        attn_bmm_output = attn_bmm_output.reshape(
+            -1, self.num_local_heads * self.v_head_dim
+        )
+        output, _ = self.o_proj(attn_bmm_output)
+
+        return output
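
Aside (illustration, not part of the commit): the prefill branch of forward_npu_sparse_core maps per-head outputs from the compressed kv_lora_rank space back to v_head_dim with a head-batched matmul. A shape-only sketch with made-up dimensions; the real code writes into a preallocated buffer via out= and uses the NPU batch_matmul_transpose op on the decode path:

import torch

tokens, heads, kv_lora_rank, v_head_dim = 8, 16, 512, 128

attn_output = torch.randn(tokens, heads, kv_lora_rank)
w_vc = torch.randn(heads, kv_lora_rank, v_head_dim)

# bmm wants the batch (head) dimension first, hence the transposes.
out = torch.bmm(attn_output.transpose(0, 1), w_vc)  # [heads, tokens, v_head_dim]
attn_bmm_output = out.transpose(0, 1).reshape(tokens, heads * v_head_dim)
print(attn_bmm_output.shape)  # torch.Size([8, 2048])
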
    def forward_absorb_fused_mla_rope_prepare(
        self,
        positions: torch.Tensor,
@@ -2134,7 +2442,6 @@ class DeepseekV2DecoderLayer(nn.Module):
        zero_allocator: BumpAllocator,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        quant_format = (
            "mxfp4"
            if _is_gfx95_supported
@@ -3099,6 +3406,7 @@ AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
AttentionBackendRegistry.register("fa4", handle_attention_fa4)
AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
AttentionBackendRegistry.register("aiter", handle_attention_aiter)
+AttentionBackendRegistry.register("nsa", handle_attention_nsa)
AttentionBackendRegistry.register("triton", handle_attention_triton)
@@ -3106,4 +3414,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    pass

-EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
+class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
+    pass
+
+
+EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]
@@ -91,6 +91,7 @@ ATTENTION_BACKEND_CHOICES = [
    "triton",
    "torch_native",
    "flex_attention",
+    "nsa",
    # NVIDIA specific
    "cutlass_mla",
    "fa3",
@@ -116,6 +117,8 @@ GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]

+NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
+
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
@@ -284,6 +287,8 @@ class ServerArgs:
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = None
    mm_attention_backend: Optional[str] = None
+    nsa_prefill: str = "flashmla_prefill"
+    nsa_decode: str = "fa3"

    # Speculative decoding
    speculative_algorithm: Optional[str] = None
@@ -719,6 +724,8 @@ class ServerArgs:
            self.sampling_backend = "pytorch"

    def _handle_model_specific_adjustments(self):
+        from sglang.srt.configs.model_config import is_deepseek_nsa
+
        if parse_connector_type(self.model_path) == ConnectorType.INSTANCE:
            return
@@ -796,6 +803,48 @@ class ServerArgs:
            )
            self.disable_hybrid_swa_memory = True

+        if is_deepseek_nsa(hf_config):
+            if (
+                self.attention_backend is None
+                and self.prefill_attention_backend is None
+                and self.decode_attention_backend is None
+            ):
+                self.attention_backend = "nsa"
+                logger.warning("Set nsa attention backend for DeepSeek NSA.")
+
+            if not is_npu():
+                self.enable_dp_attention = True
+                self.dp_size = self.tp_size
+                logger.warning("DP attention is enabled for DeepSeek NSA.")
+
+                self.page_size = 64
+                logger.warning("Setting page size to 64 for DeepSeek NSA.")
+
+                self.mem_fraction_static = 0.8
+                logger.warning("Setting mem fraction static to 0.8 for DeepSeek NSA.")
+
+                # For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
+                import torch
+
+                major, _ = torch.cuda.get_device_capability()
+                if major >= 10:
+                    self.kv_cache_dtype = "fp8_e4m3"
+                    logger.warning("Setting KV cache dtype to fp8.")
+
+                if self.kv_cache_dtype == "fp8_e4m3":
+                    self.nsa_prefill = "flashmla_decode"
+                    self.nsa_decode = "flashmla_decode"
+                    logger.warning(
+                        "Setting NSA backend to flashmla_decode for FP8 KV Cache."
+                    )
+
+            # Logging env vars for NSA
+            from sglang.srt.layers.attention.nsa.utils import (
+                print_nsa_bool_env_vars,
+            )
+
+            print_nsa_bool_env_vars()
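
Aside (illustration, not part of ServerArgs): the defaults applied above form a small decision table: bf16 KV cache stays on Hopper (SM90), fp8 is forced on Blackwell (SM100+), and an fp8 KV cache forces both NSA kernels to flashmla_decode. A standalone restatement:

def nsa_defaults(device_capability_major: int, kv_cache_dtype: str = "auto"):
    page_size = 64
    mem_fraction_static = 0.8
    if device_capability_major >= 10:  # Blackwell: fp8 KV cache only, currently
        kv_cache_dtype = "fp8_e4m3"
    if kv_cache_dtype == "fp8_e4m3":
        nsa_prefill = nsa_decode = "flashmla_decode"
    else:  # Hopper bf16 defaults
        nsa_prefill, nsa_decode = "flashmla_prefill", "fa3"
    return page_size, mem_fraction_static, kv_cache_dtype, nsa_prefill, nsa_decode


print(nsa_defaults(9))   # (64, 0.8, 'auto', 'flashmla_prefill', 'fa3')
print(nsa_defaults(10))  # (64, 0.8, 'fp8_e4m3', 'flashmla_decode', 'flashmla_decode')
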
    def _handle_sampling_backend(self):
        if self.sampling_backend is None:
            self.sampling_backend = (
@@ -1023,6 +1072,7 @@ class ServerArgs:
        model_arch = self.get_hf_config().architectures[0]
        if model_arch in [
+            "DeepseekV32ForCausalLM",
            "DeepseekV3ForCausalLM",
            "Glm4MoeForCausalLM",
            "BailingMoeForCausalLM",
@@ -1974,6 +2024,18 @@ class ServerArgs:
            default=ServerArgs.mm_attention_backend,
            help="Set multimodal attention backend.",
        )
+        parser.add_argument(
+            "--nsa-prefill",
+            default=ServerArgs.nsa_prefill,
+            type=str,
+            choices=NSA_CHOICES,
+        )
+        parser.add_argument(
+            "--nsa-decode",
+            default=ServerArgs.nsa_decode,
+            type=str,
+            choices=NSA_CHOICES,
+        )
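
Aside (illustration, not part of the commit): the two new flags parse like any other argparse option; an equivalent, hypothetical launch command is shown in the comment. Minimal sketch:

# Hypothetical launch (model path is a placeholder):
#   python -m sglang.launch_server --model <deepseek-v3.2> \
#       --attention-backend nsa --nsa-prefill flashmla_prefill --nsa-decode fa3
import argparse

NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]

parser = argparse.ArgumentParser()
parser.add_argument("--nsa-prefill", type=str, default="flashmla_prefill", choices=NSA_CHOICES)
parser.add_argument("--nsa-decode", type=str, default="fa3", choices=NSA_CHOICES)

args = parser.parse_args(["--nsa-prefill", "flashmla_prefill", "--nsa-decode", "fa3"])
print(args.nsa_prefill, args.nsa_decode)  # flashmla_prefill fa3
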
        # Speculative decoding
        parser.add_argument(
@@ -3251,6 +3313,7 @@ def auto_choose_speculative_params(self: ServerArgs):
        # The default value for llama
        return (5, 4, 8)
    elif arch in [
+        "DeepseekV32ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV2ForCausalLM",
        "GptOssForCausalLM",
...
@@ -705,6 +705,8 @@ class TboForwardBatchPreparer:
                extend_num_tokens=extend_num_tokens,
                attn_backend=output_attn_backend,
                num_token_non_padded=out_num_token_non_padded,
+                # TODO: handle it when we need TBO + DeepSeek V3.2
+                num_token_non_padded_cpu=None,
                tbo_split_seq_index=None,
                tbo_parent_token_range=(start_token_index, end_token_index),
                tbo_children=None,
...
@@ -471,7 +471,7 @@ def is_pin_memory_available() -> bool:
class LayerFn(Protocol):
-    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...


def make_layers(
@@ -482,7 +482,7 @@ def make_layers(
    prefix: str = "",
    return_tuple: bool = False,
    offloader_kwargs: Dict[str, Any] = {},
-) -> Tuple[int, int, torch.nn.ModuleList]:
+) -> Tuple[torch.nn.Module, int, int]:
    """Make a list of layers with the given layer function"""
    # circula imports
    from sglang.srt.distributed import get_pp_indices
...