Commit 20b8d230 (unverified)
Authored by Baizhou Zhang on Oct 17, 2025; committed by GitHub on Oct 17, 2025

Cleaning indexer for DeepSeek V3.2 (#11682)

Parent: d1984e21
Showing 2 changed files with 3 additions and 66 deletions (+3, -66)
python/sglang/srt/layers/attention/nsa/nsa_indexer.py  +3 -65
python/sglang/srt/layers/attention/nsa/utils.py        +0 -1
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -17,7 +17,7 @@ if is_cuda():
 except ImportError as e:
     deep_gemm = e
 
-from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
+from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM
 from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization import deep_gemm_wrapper
@@ -168,43 +168,6 @@ class Indexer(CustomOp):
         self.scale_fmt = scale_fmt
         self.softmax_scale = self.head_dim**-0.5
 
-    def _forward_fake(
-        self,
-        x: torch.Tensor,
-        q_lora: torch.Tensor,
-        positions: torch.Tensor,
-        forward_batch: ForwardBatch,
-        layer_id: int,
-    ):
-        bs = x.shape[0]
-        assert self.index_topk == 2048
-        ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
-            None, ...
-        ].repeat(bs, 1)
-        if forward_batch.forward_mode.is_extend():
-            assert (
-                forward_batch.extend_seq_lens_cpu is not None
-                and forward_batch.seq_lens_cpu is not None
-            )
-            which = 0
-            for i, (kv_len, qo_len) in enumerate(
-                zip(
-                    forward_batch.seq_lens_cpu.tolist(),
-                    forward_batch.extend_seq_lens_cpu,
-                    strict=True,
-                )
-            ):
-                for j in range(kv_len - qo_len, kv_len):
-                    ans[which, j + 1 :] = -1
-                    which += 1
-            assert which == ans.shape[0]
-        else:
-            assert forward_batch.seq_lens_cpu is not None
-            for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
-                ans[i, seq_len:] = -1
-        return ans
-
     @torch.compile(dynamic=True)
     def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
         weights, _ = self.weights_proj(x)
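Note (not part of the commit): the deleted _forward_fake built a placeholder top-k index tensor, one row of arange(index_topk) per token, with entries beyond each token's causal position overwritten with -1. A minimal standalone sketch of that masking, using hypothetical toy sizes instead of the real index_topk of 2048:

import torch

# Toy re-creation of the removed _forward_fake masking (hypothetical sizes, CPU only).
index_topk = 8
seq_lens = [5, 3]          # kv_len per request (hypothetical)
extend_seq_lens = [2, 3]   # qo_len per request (hypothetical)

num_tokens = sum(extend_seq_lens)
ans = torch.arange(0, index_topk, dtype=torch.int32)[None, :].repeat(num_tokens, 1)

which = 0
for kv_len, qo_len in zip(seq_lens, extend_seq_lens, strict=True):
    for j in range(kv_len - qo_len, kv_len):
        ans[which, j + 1 :] = -1  # the token at position j may only index keys 0..j
        which += 1
assert which == ans.shape[0]
print(ans)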
@@ -404,7 +367,7 @@ class Indexer(CustomOp):
         return topk_result
 
-    def forward_indexer_bs_1(
+    def forward_indexer(
         self,
         q_fp8: torch.Tensor,
         weights: torch.Tensor,
@@ -485,20 +448,9 @@ class Indexer(CustomOp):
             q_len_start = q_len_end
 
         topk_indices = torch.cat(topk_indices_list, dim=0)
         return topk_indices
 
-    def forward_indexer(
-        self,
-        q_fp8: torch.Tensor,
-        weights: torch.Tensor,
-        forward_batch: ForwardBatch,
-        topk: int,
-        layer_id: int,
-    ) -> Optional[torch.Tensor]:
-        return self.forward_indexer_bs_1(
-            q_fp8, weights, forward_batch, topk, layer_id
-        )
-
-    def _forward(
+    def forward_cuda(
         self,
         x: torch.Tensor,
         q_lora: torch.Tensor,
@@ -530,9 +482,6 @@ class Indexer(CustomOp):
         if metadata is None:
             return None
 
-        if not NSA_USE_REAL_INDEXER:  # temporary
-            return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
-
         query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
 
         if enable_dual_stream:
@@ -588,19 +537,8 @@ class Indexer(CustomOp):
                 topk=self.index_topk,
                 layer_id=layer_id,
             )
         return topk_result
 
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        q_lora: torch.Tensor,
-        positions: torch.Tensor,
-        forward_batch: ForwardBatch,
-        layer_id: int,
-    ) -> Optional[torch.Tensor]:
-        return self._forward(x, q_lora, positions, forward_batch, layer_id)
-
     def forward_npu(
         self,
         x: torch.Tensor,
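Taken together (a reviewer's reading of the diff, not text from the commit): the fake-indexer fallback and its NSA_USE_REAL_INDEXER guard are gone, forward_indexer_bs_1 now carries the public forward_indexer name, and the former _forward body is exposed directly as forward_cuda, dropping the one-line wrapper. A structural sketch of the resulting class surface, with bodies elided:

from typing import Optional

import torch

class Indexer:  # stand-in skeleton; the real class derives from sglang's CustomOp
    def forward_indexer(
        self,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        forward_batch,          # ForwardBatch in sglang
        topk: int,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """Per-request top-k index selection (formerly forward_indexer_bs_1)."""
        ...

    def forward_cuda(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch,          # ForwardBatch in sglang
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """Formerly _forward; the CUDA entry point, with no fake-indexer branch."""
        ...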
python/sglang/srt/layers/attention/nsa/utils.py
 # temp NSA debugging environ
 from sglang.srt.utils import get_bool_env_var
 
-NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
 NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
 NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")
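The surviving flags are plain boolean environment toggles, read once at import time and defaulting to enabled. A hedged sketch of how they behave (the helper below is a stand-in, not sglang's actual get_bool_env_var):

import os

def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Stand-in for sglang.srt.utils.get_bool_env_var: accept "1"/"true"/"yes", case-insensitive.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

NSA_DUAL_STREAM = get_bool_env_var_sketch("SGLANG_NSA_DUAL_STREAM", "true")
NSA_FUSE_TOPK = get_bool_env_var_sketch("SGLANG_NSA_FUSE_TOPK", "true")

# Example: disable dual-stream for a run by setting the variable before launch, e.g.
#   SGLANG_NSA_DUAL_STREAM=false python -m sglang.launch_server ...
print(NSA_DUAL_STREAM, NSA_FUSE_TOPK)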