Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6a86ea6d
Commit
6a86ea6d
authored
Apr 09, 2026
by
wanghl6
Browse files
[DSA][BUGFIX]解决mqa_logits开PC时大bs导致的oom问题
parent
1edffefe
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
10 deletions
+6
-10
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+6
-10
No files found.
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
6a86ea6d
...
@@ -20,10 +20,6 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri
...
@@ -20,10 +20,6 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri
from
vllm.v1.worker.workspace
import
current_workspace_manager
from
vllm.v1.worker.workspace
import
current_workspace_manager
from
lightop
import
op
,
gemmopt
from
lightop
import
op
,
gemmopt
from
vllm.attention.utils.kv_transfer_utils
import
(
maybe_transfer_kv_layer
,
)
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
elif
current_platform
.
is_xpu
():
elif
current_platform
.
is_xpu
():
...
@@ -31,10 +27,10 @@ elif current_platform.is_xpu():
...
@@ -31,10 +27,10 @@ elif current_platform.is_xpu():
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
maybe_transfer_kv_layer
def
sparse_attn_indexer
(
def
sparse_attn_indexer
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
layer_name
:
str
,
k_cache_prefix
:
str
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
@@ -60,7 +56,7 @@ def sparse_attn_indexer(
...
@@ -60,7 +56,7 @@ def sparse_attn_indexer(
)
)
return
sparse_attn_indexer_fake
(
return
sparse_attn_indexer_fake
(
hidden_states
,
hidden_states
,
layer_name
,
k_cache_prefix
,
kv_cache
,
kv_cache
,
q_fp8
,
q_fp8
,
k
,
k
,
...
@@ -73,9 +69,9 @@ def sparse_attn_indexer(
...
@@ -73,9 +69,9 @@ def sparse_attn_indexer(
total_seq_lens
,
total_seq_lens
,
topk_indices_buffer
,
topk_indices_buffer
,
)
)
attn_metadata
=
attn_metadata
[
layer_name
]
attn_metadata
=
attn_metadata
[
k_cache_prefix
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
[:
attn_metadata
.
num_kv_actual_tokens
]
slot_mapping
=
attn_metadata
.
slot_mapping
has_decode
=
attn_metadata
.
num_decodes
>
0
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
...
@@ -322,7 +318,7 @@ def sparse_attn_indexer(
...
@@ -322,7 +318,7 @@ def sparse_attn_indexer(
def
sparse_attn_indexer_fake
(
def
sparse_attn_indexer_fake
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
layer_name
:
str
,
k_cache_prefix
:
str
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment