Unverified Commit efbc687c authored by fzyzcjy, committed by GitHub
parent 292a867a
......@@ -293,6 +293,7 @@ class ForwardBatch:
# For padding
padded_static_len: int = -1 # -1 if not padded
num_token_non_padded: Optional[torch.Tensor] = None # scalar tensor
num_token_non_padded_cpu: Optional[int] = None
# For Qwen2-VL
mrope_positions: torch.Tensor = None
......@@ -354,6 +355,7 @@ class ForwardBatch:
ret.num_token_non_padded = torch.tensor(
len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True)
ret.num_token_non_padded_cpu = len(batch.input_ids)
# For MLP sync
if batch.global_num_tokens is not None:
......
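A minimal sketch (not part of the diff; host_side_token_count is a hypothetical helper) of why the CPU copy is kept next to the device tensor: host-side logic can read the token count without forcing a GPU-to-CPU synchronization.

import torch
from typing import Optional

def host_side_token_count(num_token_non_padded: torch.Tensor,
                          num_token_non_padded_cpu: Optional[int]) -> int:
    # The plain Python int is recorded at batch-construction time, so reading it is free.
    if num_token_non_padded_cpu is not None:
        return num_token_non_padded_cpu
    # Falling back to the device tensor forces a blocking device-to-host copy.
    return int(num_token_non_padded.item())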
......@@ -31,7 +31,12 @@ import torch.distributed as dist
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.configs.model_config import (
AttentionArch,
ModelConfig,
get_nsa_index_head_dim,
is_deepseek_nsa,
)
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.distributed import (
......@@ -96,6 +101,7 @@ from sglang.srt.mem_cache.memory_pool import (
HybridReqToTokenPool,
MHATokenToKVPool,
MLATokenToKVPool,
NSATokenToKVPool,
ReqToTokenPool,
SWAKVPool,
)
......@@ -157,6 +163,7 @@ MLA_ATTENTION_BACKENDS = [
"cutlass_mla",
"trtllm_mla",
"ascend",
"nsa",
]
......@@ -1547,6 +1554,7 @@ class ModelRunner:
assert self.is_draft_worker
# Initialize token_to_kv_pool
is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
if self.server_args.attention_backend == "ascend":
if self.use_mla_backend:
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
......@@ -1555,6 +1563,7 @@ class ModelRunner:
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
index_head_dim=self.model_config.index_head_dim,
layer_num=self.num_effective_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
......@@ -1574,7 +1583,22 @@ class ModelRunner:
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
elif self.use_mla_backend and is_nsa_model:
self.token_to_kv_pool = NSATokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.num_effective_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
)
elif self.use_mla_backend:
assert not is_nsa_model
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
......
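A condensed sketch of the pool-selection order this hunk introduces (illustrative only; constructor arguments are omitted and the non-MLA fallback is handled outside this hunk):

def select_kv_pool_cls(attention_backend: str, use_mla_backend: bool, is_nsa_model: bool) -> str:
    if attention_backend == "ascend" and use_mla_backend:
        return "AscendMLAPagedTokenToKVPool"   # now also receives index_head_dim
    if use_mla_backend and is_nsa_model:
        return "NSATokenToKVPool"              # new NSA-specific pool (index_head_dim from the HF config)
    if use_mla_backend:
        return "MLATokenToKVPool"              # unchanged MLA path; asserts not is_nsa_model
    return "MHATokenToKVPool"                  # non-MLA pools are chosen outside this hunk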
......@@ -75,11 +75,16 @@ class NPUGraphRunner(CudaGraphRunner):
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
# Replay
seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
thread.start()
self.graphs[self.bs].replay()
thread.join()
if self.model_runner.model_config.index_head_dim is None:
seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
self.bs - self.raw_bs
)
thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
thread.start()
self.graphs[self.bs].replay()
thread.join()
else:
self.graphs[self.bs].replay()
output = self.output_buffers[self.bs]
if isinstance(output, LogitsProcessorOutput):
......
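The replay path above overlaps a host-side input update with the graph replay only for models without an NSA index head. A minimal, generic sketch of that pattern, assuming replay and update_inputs are callables supplied by the caller:

import threading
from typing import Callable, List

def replay_with_optional_host_update(replay: Callable[[], None],
                                     update_inputs: Callable[[List[int]], None],
                                     seq_lens: List[int],
                                     has_nsa_index: bool) -> None:
    if has_nsa_index:
        # Models with index_head_dim set (NSA) replay the captured graph directly.
        replay()
        return
    worker = threading.Thread(target=update_inputs, args=(seq_lens,))
    worker.start()   # the host-side input update overlaps with the graph replay
    replay()
    worker.join()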
This diff is collapsed.
......@@ -91,6 +91,7 @@ ATTENTION_BACKEND_CHOICES = [
"triton",
"torch_native",
"flex_attention",
"nsa",
# NVIDIA specific
"cutlass_mla",
"fa3",
......@@ -116,6 +117,8 @@ GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
......@@ -284,6 +287,8 @@ class ServerArgs:
sampling_backend: Optional[str] = None
grammar_backend: Optional[str] = None
mm_attention_backend: Optional[str] = None
nsa_prefill: str = "flashmla_prefill"
nsa_decode: str = "fa3"
# Speculative decoding
speculative_algorithm: Optional[str] = None
......@@ -719,6 +724,8 @@ class ServerArgs:
self.sampling_backend = "pytorch"
def _handle_model_specific_adjustments(self):
from sglang.srt.configs.model_config import is_deepseek_nsa
if parse_connector_type(self.model_path) == ConnectorType.INSTANCE:
return
......@@ -796,6 +803,48 @@ class ServerArgs:
)
self.disable_hybrid_swa_memory = True
if is_deepseek_nsa(hf_config):
if (
self.attention_backend is None
and self.prefill_attention_backend is None
and self.decode_attention_backend is None
):
self.attention_backend = "nsa"
logger.warning("Set nsa attention backend for DeepSeek NSA.")
if not is_npu():
self.enable_dp_attention = True
self.dp_size = self.tp_size
logger.warning("DP attention is enabled for DeepSeek NSA.")
self.page_size = 64
logger.warning("Setting page size to 64 for DeepSeek NSA.")
self.mem_fraction_static = 0.8
logger.warning("Setting mem fraction static to 0.8 for DeepSeek NSA.")
# For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
import torch
major, _ = torch.cuda.get_device_capability()
if major >= 10:
self.kv_cache_dtype = "fp8_e4m3"
logger.warning("Setting KV cache dtype to fp8.")
if self.kv_cache_dtype == "fp8_e4m3":
self.nsa_prefill = "flashmla_decode"
self.nsa_decode = "flashmla_decode"
logger.warning(
"Setting NSA backend to flashmla_decode for FP8 KV Cache."
)
# Logging env vars for NSA
from sglang.srt.layers.attention.nsa.utils import (
print_nsa_bool_env_vars,
)
print_nsa_bool_env_vars()
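A condensed sketch of the defaults this block applies for DeepSeek NSA models (values copied from the code above; NPU handling, DP attention, and env-var logging are omitted):

def nsa_server_defaults(kv_cache_dtype: str, sm_major: int) -> dict:
    args = {
        "attention_backend": "nsa",
        "page_size": 64,
        "mem_fraction_static": 0.8,
        "nsa_prefill": "flashmla_prefill",
        "nsa_decode": "fa3",
    }
    if sm_major >= 10:                 # Blackwell: only fp8 KV cache is supported currently
        kv_cache_dtype = "fp8_e4m3"
    if kv_cache_dtype == "fp8_e4m3":   # fp8 KV cache requires the flashmla_decode NSA kernels
        args["nsa_prefill"] = args["nsa_decode"] = "flashmla_decode"
    args["kv_cache_dtype"] = kv_cache_dtype
    return args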
def _handle_sampling_backend(self):
if self.sampling_backend is None:
self.sampling_backend = (
......@@ -1023,6 +1072,7 @@ class ServerArgs:
model_arch = self.get_hf_config().architectures[0]
if model_arch in [
"DeepseekV32ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
"BailingMoeForCausalLM",
......@@ -1974,6 +2024,18 @@ class ServerArgs:
default=ServerArgs.mm_attention_backend,
help="Set multimodal attention backend.",
)
parser.add_argument(
"--nsa-prefill",
default=ServerArgs.nsa_prefill,
type=str,
choices=NSA_CHOICES,
)
parser.add_argument(
"--nsa-decode",
default=ServerArgs.nsa_decode,
type=str,
choices=NSA_CHOICES,
)
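On the command line these surface as --nsa-prefill and --nsa-decode, each restricted to NSA_CHOICES (flashmla_prefill, flashmla_decode, fa3, tilelang, aiter), for example: --nsa-prefill flashmla_prefill --nsa-decode fa3.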
# Speculative decoding
parser.add_argument(
......@@ -3251,6 +3313,7 @@ def auto_choose_speculative_params(self: ServerArgs):
# The default value for llama
return (5, 4, 8)
elif arch in [
"DeepseekV32ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV2ForCausalLM",
"GptOssForCausalLM",
......
......@@ -705,6 +705,8 @@ class TboForwardBatchPreparer:
extend_num_tokens=extend_num_tokens,
attn_backend=output_attn_backend,
num_token_non_padded=out_num_token_non_padded,
# TODO: handle it when we need TBO + DeepSeek V3.2
num_token_non_padded_cpu=None,
tbo_split_seq_index=None,
tbo_parent_token_range=(start_token_index, end_token_index),
tbo_children=None,
......
......@@ -471,7 +471,7 @@ def is_pin_memory_available() -> bool:
class LayerFn(Protocol):
def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...
def make_layers(
......@@ -482,7 +482,7 @@ def make_layers(
prefix: str = "",
return_tuple: bool = False,
offloader_kwargs: Dict[str, Any] = {},
) -> Tuple[int, int, torch.nn.ModuleList]:
) -> Tuple[torch.nn.Module, int, int]:
"""Make a list of layers with the given layer function"""
# circular imports
from sglang.srt.distributed import get_pp_indices
......
......@@ -123,6 +123,38 @@ def get_hf_text_config(config: PretrainedConfig):
return config
# Temporary hack for DeepSeek-V3.2 model
def _load_deepseek_v32_model(
model_path: str,
trust_remote_code: bool = False,
revision: Optional[str] = None,
**kwargs,
):
# first get the local path
local_path = download_from_hf(model_path)
# then load the config file in json
config_file = os.path.join(local_path, "config.json")
if not os.path.exists(config_file):
raise RuntimeError(f"Can't find config file in {local_path}.")
with open(config_file, "r") as f:
config_json = json.load(f)
config_json["architectures"] = ["DeepseekV3ForCausalLM"]
config_json["model_type"] = "deepseek_v3"
tmp_path = os.path.join(local_path, "_tmp_config_folder")
os.makedirs(tmp_path, exist_ok=True)
unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
with open(unique_path, "w") as f:
json.dump(config_json, f)
return AutoConfig.from_pretrained(
unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
@lru_cache_frozenset(maxsize=32)
def get_config(
model: str,
......@@ -144,9 +176,17 @@ def get_config(
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
model = client.get_local_dir()
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
except ValueError as e:
if "deepseek_v32" not in str(e):
raise e
config = _load_deepseek_v32_model(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
if (
config.architectures is not None
and config.architectures[0] == "Phi4MMForCausalLM"
......
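Also included is a standalone micro-benchmark for the logits head-gate computation: it times the original formulation against one that folds the scalar factors (n_heads ** -0.5, q_scale, softmax_scale) into a single per-batch constant before the broadcast multiply, and checks that both variants produce identical outputs.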
import torch
import torch.nn as nn
class DummyModel(nn.Module):
def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5):
super().__init__()
self.weights_proj = nn.Linear(d_in, 1024)
self.n_heads = n_heads
self.softmax_scale = softmax_scale
def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
weights = self.weights_proj(x)
weights = weights * self.n_heads**-0.5
q_scale = q_scale.unsqueeze(1) # (B,1,1)
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
return weights
def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
weights = self.weights_proj(x)
q_scale = q_scale.unsqueeze(1) # (B,1,1)
scale_const = self.n_heads**-0.5 * q_scale * self.softmax_scale # (B,1,1)
weights = weights.unsqueeze(-1) * scale_const # (B,1024,1)
return weights
def main():
torch.manual_seed(0)
model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
x = torch.randn(128, 2048) # batch=128, d_in=2048
q_scale = torch.randn(128, 1)
import time
start = time.time()
for _ in range(1000):
out_orig = model._get_logits_head_gate_orig(x, q_scale)
print("Original version time:", time.time() - start)
start = time.time()
for _ in range(1000):
out_opt = model._get_logits_head_gate_opt(x, q_scale)
print("Optimized version time:", time.time() - start)
print("Difference:", (out_orig - out_opt).abs().max().item())
assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"
if __name__ == "__main__":
main()
"""
Original version time: 0.49235057830810547
Optimized version time: 0.4087331295013428
Difference: 1.4901161193847656e-08
"""