Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
20b8d230
Unverified
Commit
20b8d230
authored
Oct 17, 2025
by
Baizhou Zhang
Committed by
GitHub
Oct 17, 2025
Browse files
Cleaning indexer for DeepSeek V3.2 (#11682)
parent
d1984e21
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
66 deletions
+3
-66
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
+3
-65
python/sglang/srt/layers/attention/nsa/utils.py
python/sglang/srt/layers/attention/nsa/utils.py
+0
-1
No files found.
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
View file @
20b8d230
...
...
@@ -17,7 +17,7 @@ if is_cuda():
except
ImportError
as
e
:
deep_gemm
=
e
from
sglang.srt.layers.attention.nsa.utils
import
NSA_DUAL_STREAM
,
NSA_USE_REAL_INDEXER
from
sglang.srt.layers.attention.nsa.utils
import
NSA_DUAL_STREAM
from
sglang.srt.layers.dp_attention
import
get_attention_tp_group
from
sglang.srt.layers.linear
import
ReplicatedLinear
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
...
...
@@ -168,43 +168,6 @@ class Indexer(CustomOp):
self
.
scale_fmt
=
scale_fmt
self
.
softmax_scale
=
self
.
head_dim
**-
0.5
def
_forward_fake
(
self
,
x
:
torch
.
Tensor
,
q_lora
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
layer_id
:
int
,
):
bs
=
x
.
shape
[
0
]
assert
self
.
index_topk
==
2048
ans
=
torch
.
arange
(
0
,
self
.
index_topk
,
dtype
=
torch
.
int32
,
device
=
x
.
device
)[
None
,
...
].
repeat
(
bs
,
1
)
if
forward_batch
.
forward_mode
.
is_extend
():
assert
(
forward_batch
.
extend_seq_lens_cpu
is
not
None
and
forward_batch
.
seq_lens_cpu
is
not
None
)
which
=
0
for
i
,
(
kv_len
,
qo_len
)
in
enumerate
(
zip
(
forward_batch
.
seq_lens_cpu
.
tolist
(),
forward_batch
.
extend_seq_lens_cpu
,
strict
=
True
,
)
):
for
j
in
range
(
kv_len
-
qo_len
,
kv_len
):
ans
[
which
,
j
+
1
:]
=
-
1
which
+=
1
assert
which
==
ans
.
shape
[
0
]
else
:
assert
forward_batch
.
seq_lens_cpu
is
not
None
for
i
,
seq_len
in
enumerate
(
forward_batch
.
seq_lens_cpu
.
tolist
()):
ans
[
i
,
seq_len
:]
=
-
1
return
ans
@
torch
.
compile
(
dynamic
=
True
)
def
_get_logits_head_gate
(
self
,
x
:
torch
.
Tensor
,
q_scale
:
torch
.
Tensor
):
weights
,
_
=
self
.
weights_proj
(
x
)
...
...
@@ -404,7 +367,7 @@ class Indexer(CustomOp):
return
topk_result
def
forward_indexer
_bs_1
(
def
forward_indexer
(
self
,
q_fp8
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
...
...
@@ -485,20 +448,9 @@ class Indexer(CustomOp):
q_len_start
=
q_len_end
topk_indices
=
torch
.
cat
(
topk_indices_list
,
dim
=
0
)
return
topk_indices
def
forward_indexer
(
self
,
q_fp8
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
topk
:
int
,
layer_id
:
int
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
forward_indexer_bs_1
(
q_fp8
,
weights
,
forward_batch
,
topk
,
layer_id
)
def
_forward
(
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
,
q_lora
:
torch
.
Tensor
,
...
...
@@ -530,9 +482,6 @@ class Indexer(CustomOp):
if
metadata
is
None
:
return
None
if
not
NSA_USE_REAL_INDEXER
:
# temporary
return
self
.
_forward_fake
(
x
,
q_lora
,
positions
,
forward_batch
,
layer_id
)
query
,
key
=
self
.
_get_q_k_bf16
(
q_lora
,
x
,
positions
,
enable_dual_stream
)
if
enable_dual_stream
:
...
...
@@ -588,19 +537,8 @@ class Indexer(CustomOp):
topk
=
self
.
index_topk
,
layer_id
=
layer_id
,
)
return
topk_result
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
,
q_lora
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
layer_id
:
int
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
_forward
(
x
,
q_lora
,
positions
,
forward_batch
,
layer_id
)
def
forward_npu
(
self
,
x
:
torch
.
Tensor
,
...
...
python/sglang/srt/layers/attention/nsa/utils.py
View file @
20b8d230
# temp NSA debugging environ
from
sglang.srt.utils
import
get_bool_env_var
NSA_USE_REAL_INDEXER
=
get_bool_env_var
(
"SGLANG_NSA_USE_REAL_INDEXER"
,
"true"
)
NSA_DUAL_STREAM
=
get_bool_env_var
(
"SGLANG_NSA_DUAL_STREAM"
,
"true"
)
NSA_FUSE_TOPK
=
get_bool_env_var
(
"SGLANG_NSA_FUSE_TOPK"
,
"true"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment