Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cb68935c
Commit
cb68935c
authored
Mar 26, 2026
by
wanghl6
Browse files
topk opt
parent
0bd5fcd2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
24 deletions
+60
-24
vllm/envs.py
vllm/envs.py
+13
-2
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+47
-22
No files found.
vllm/envs.py
View file @
cb68935c
...
...
@@ -320,8 +320,9 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK
:
bool
=
False
VLLM_V1_USE_FA_UNIFIED_ATTN_2D
:
bool
=
False
VLLM_ENABLE_RAY_ASYNC_SCHEDULING
:
bool
=
False
USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8
:
bool
=
False
USE_LIGHTOP_TOPK
:
bool
=
False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX
:
bool
=
False
def
get_default_cache_root
():
return
os
.
getenv
(
"XDG_CACHE_HOME"
,
...
...
@@ -1990,6 +1991,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
"USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8"
:
lambda
:
(
os
.
environ
.
get
(
"USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
"USE_LIGHTOP_TOPK"
:
lambda
:
(
os
.
environ
.
get
(
"USE_LIGHTOP_TOPK"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
"USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX"
:
lambda
:
(
os
.
environ
.
get
(
"USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
cb68935c
...
...
@@ -3,7 +3,7 @@
"""Custom Sparse Attention Indexer layers."""
import
torch
import
vllm.envs
as
envs
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
...
...
@@ -170,6 +170,7 @@ def sparse_attn_indexer(
topk_indices
=
topk_indices_buffer
[
chunk
.
token_start
:
chunk
.
token_end
,
:
topk_tokens
]
if
not
envs
.
USE_LIGHTOP_TOPK
:
torch
.
ops
.
_C
.
top_k_per_row_prefill
(
logits
,
chunk
.
cu_seqlen_ks
,
...
...
@@ -180,6 +181,17 @@ def sparse_attn_indexer(
logits
.
stride
(
1
),
topk_tokens
,
)
else
:
op
.
top_k_per_row_prefill
(
logits
,
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
topk_tokens
,
)
if
has_decode
:
decode_metadata
=
attn_metadata
.
decode
...
...
@@ -230,6 +242,9 @@ def sparse_attn_indexer(
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[:
num_padded_tokens
,
:
topk_tokens
]
# if torch.distributed.get_rank() == 0:
# print(f"====[DEBUG] logits shape: {logits.shape}, next_n: {next_n}, topk_tokens size: {topk_tokens}")
if
not
envs
.
USE_LIGHTOP_TOPK
:
torch
.
ops
.
_C
.
top_k_per_row_decode
(
logits
,
next_n
,
...
...
@@ -240,7 +255,17 @@ def sparse_attn_indexer(
logits
.
stride
(
1
),
topk_tokens
,
)
else
:
op
.
top_k_per_row_decode
(
logits
,
next_n
,
decode_metadata
.
seq_lens
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
topk_tokens
,
)
if
decode_metadata
.
requires_padding
:
# if padded, we need to unpack
# the topk indices removing padded tokens
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment