Unverified Commit 20b8d230 authored by Baizhou Zhang's avatar Baizhou Zhang Committed by GitHub
Browse files

Cleaning indexer for DeepSeek V3.2 (#11682)

parent d1984e21
......@@ -17,7 +17,7 @@ if is_cuda():
except ImportError as e:
deep_gemm = e
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
......@@ -168,43 +168,6 @@ class Indexer(CustomOp):
self.scale_fmt = scale_fmt
self.softmax_scale = self.head_dim**-0.5
def _forward_fake(
self,
x: torch.Tensor,
q_lora: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
):
bs = x.shape[0]
assert self.index_topk == 2048
ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
None, ...
].repeat(bs, 1)
if forward_batch.forward_mode.is_extend():
assert (
forward_batch.extend_seq_lens_cpu is not None
and forward_batch.seq_lens_cpu is not None
)
which = 0
for i, (kv_len, qo_len) in enumerate(
zip(
forward_batch.seq_lens_cpu.tolist(),
forward_batch.extend_seq_lens_cpu,
strict=True,
)
):
for j in range(kv_len - qo_len, kv_len):
ans[which, j + 1 :] = -1
which += 1
assert which == ans.shape[0]
else:
assert forward_batch.seq_lens_cpu is not None
for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
ans[i, seq_len:] = -1
return ans
@torch.compile(dynamic=True)
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
weights, _ = self.weights_proj(x)
......@@ -404,7 +367,7 @@ class Indexer(CustomOp):
return topk_result
def forward_indexer_bs_1(
def forward_indexer(
self,
q_fp8: torch.Tensor,
weights: torch.Tensor,
......@@ -485,20 +448,9 @@ class Indexer(CustomOp):
q_len_start = q_len_end
topk_indices = torch.cat(topk_indices_list, dim=0)
return topk_indices
def forward_indexer(
self,
q_fp8: torch.Tensor,
weights: torch.Tensor,
forward_batch: ForwardBatch,
topk: int,
layer_id: int,
) -> Optional[torch.Tensor]:
return self.forward_indexer_bs_1(q_fp8, weights, forward_batch, topk, layer_id)
def _forward(
def forward_cuda(
self,
x: torch.Tensor,
q_lora: torch.Tensor,
......@@ -530,9 +482,6 @@ class Indexer(CustomOp):
if metadata is None:
return None
if not NSA_USE_REAL_INDEXER: # temporary
return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
if enable_dual_stream:
......@@ -588,19 +537,8 @@ class Indexer(CustomOp):
topk=self.index_topk,
layer_id=layer_id,
)
return topk_result
def forward_cuda(
self,
x: torch.Tensor,
q_lora: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
) -> Optional[torch.Tensor]:
return self._forward(x, q_lora, positions, forward_batch, layer_id)
def forward_npu(
self,
x: torch.Tensor,
......
# temp NSA debugging environ
from sglang.srt.utils import get_bool_env_var

# When false, the indexer returns fake (trivially causal) top-k indices
# instead of running the real indexer kernel; defaults to enabled.
NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
# NOTE(review): presumably toggles the dual-CUDA-stream path in the
# indexer forward — confirm against its usage sites.
NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
# NOTE(review): presumably fuses top-k selection into the indexer
# kernel — semantics not visible from this fragment; confirm.
NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment