Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1edffefe
Commit
1edffefe
authored
Apr 09, 2026
by
wanghl6
Browse files
[BUGFIX]解决mqa_logits大bs导致的oom问题
parent
bcb2ba6c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
98 additions
and
64 deletions
+98
-64
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+98
-64
No files found.
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
1edffefe
...
@@ -75,7 +75,7 @@ def sparse_attn_indexer(
...
@@ -75,7 +75,7 @@ def sparse_attn_indexer(
)
)
attn_metadata
=
attn_metadata
[
layer_name
]
attn_metadata
=
attn_metadata
[
layer_name
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
slot_mapping
=
attn_metadata
.
slot_mapping
[:
attn_metadata
.
num_kv_actual_tokens
]
has_decode
=
attn_metadata
.
num_decodes
>
0
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
...
@@ -116,14 +116,6 @@ def sparse_attn_indexer(
...
@@ -116,14 +116,6 @@ def sparse_attn_indexer(
chunk
.
block_table
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
chunk
.
cu_seq_lens
,
)
)
logits
=
fp8_mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
).
flatten
()),
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
)
elif
get_gcn_arch_name
()
==
"gfx938"
:
elif
get_gcn_arch_name
()
==
"gfx938"
:
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
...
@@ -134,19 +126,6 @@ def sparse_attn_indexer(
...
@@ -134,19 +126,6 @@ def sparse_attn_indexer(
chunk
.
block_table
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
chunk
.
cu_seq_lens
,
)
)
logits
=
op
.
mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
k_fp8
,
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
].
shape
[
0
],
k_fp8
.
shape
[
0
],
q_fp8
.
shape
[
1
],
q_fp8
.
shape
[
2
],
k_scale
.
view
(
torch
.
float32
).
flatten
(),
True
)
else
:
else
:
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
...
@@ -156,46 +135,104 @@ def sparse_attn_indexer(
...
@@ -156,46 +135,104 @@ def sparse_attn_indexer(
chunk
.
block_table
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
chunk
.
cu_seq_lens
,
)
)
logits
=
op
.
mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
k_fp8
,
weights
[
chunk
.
token_start
:
chunk
.
token_end
].
to
(
torch
.
float32
),
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
].
shape
[
0
],
k_fp8
.
shape
[
0
],
q_fp8
.
shape
[
1
],
q_fp8
.
shape
[
2
],
None
,
True
)
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[
q_all
=
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
]
chunk
.
token_start
:
chunk
.
token_end
,
:
topk_tokens
weights_all
=
weights
[
chunk
.
token_start
:
chunk
.
token_end
]
]
ks_all
=
chunk
.
cu_seqlen_ks
if
not
envs
.
USE_LIGHTOP_TOPK
:
ke_all
=
chunk
.
cu_seqlen_ke
torch
.
ops
.
_C
.
top_k_per_row_prefill
(
logits
,
num_q
=
q_all
.
shape
[
0
]
chunk
.
cu_seqlen_ks
,
num_k
=
k_fp8
.
shape
[
0
]
chunk
.
cu_seqlen_ke
,
topk_indices
,
MAX_ELEMENTS
=
1024
*
1024
*
1024
# 4GB
num_rows
,
if
(
num_q
<=
65536
and
num_k
<=
65536
):
# if num_q <= 65536 and num_k <= 65536 and (num_q * num_k <= MAX_ELEMENTS):
logits
.
stride
(
0
),
MAX_Q_CHUNK
=
max
(
1
,
num_q
)
logits
.
stride
(
1
),
topk_tokens
,
)
else
:
else
:
op
.
top_k_per_row_prefill
(
MAX_Q_CHUNK
=
max
(
1024
,
MAX_ELEMENTS
//
max
(
1
,
num_k
))
logits
,
MAX_Q_CHUNK
=
min
(
MAX_Q_CHUNK
,
max
(
1
,
num_q
))
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
#存储q的起始和终止地址
topk_indices
,
slices
=
[]
num_rows
,
logits
.
stride
(
0
),
for
start_idx
in
range
(
0
,
num_q
,
MAX_Q_CHUNK
):
logits
.
stride
(
1
),
end_idx
=
min
(
start_idx
+
MAX_Q_CHUNK
,
num_q
)
topk_tokens
,
slices
.
append
((
start_idx
,
end_idx
))
)
for
q_start
,
q_end
in
slices
:
if
q_end
<=
q_start
:
continue
q_slice
=
q_all
[
q_start
:
q_end
]
weights_slice
=
weights_all
[
q_start
:
q_end
]
ks_slice
=
ks_all
[
q_start
:
q_end
]
ke_slice
=
ke_all
[
q_start
:
q_end
]
if
not
current_platform
.
is_rocm
():
logits_slice
=
fp8_mqa_logits
(
q_slice
,
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
).
flatten
()),
weights_slice
,
ks_slice
,
ke_slice
,
)
elif
get_gcn_arch_name
()
==
"gfx938"
:
logits_slice
=
op
.
mqa_logits
(
q_slice
,
k_fp8
,
weights_slice
,
ks_slice
,
ke_slice
,
q_slice
.
shape
[
0
],
k_fp8
.
shape
[
0
],
q_slice
.
shape
[
1
],
q_slice
.
shape
[
2
],
k_scale
.
view
(
torch
.
float32
).
flatten
(),
True
)
else
:
logits_slice
=
op
.
mqa_logits
(
q_slice
,
k_fp8
,
weights_slice
.
to
(
torch
.
float32
),
ks_slice
,
ke_slice
,
q_slice
.
shape
[
0
],
k_fp8
.
shape
[
0
],
q_slice
.
shape
[
1
],
q_slice
.
shape
[
2
],
None
,
True
)
num_rows_slice
=
logits_slice
.
shape
[
0
]
topk_indices_slice
=
topk_indices_buffer
[
chunk
.
token_start
+
q_start
:
chunk
.
token_start
+
q_end
,
:
topk_tokens
]
if
not
envs
.
USE_LIGHTOP_TOPK
:
torch
.
ops
.
_C
.
top_k_per_row_prefill
(
logits_slice
,
ks_slice
,
ke_slice
,
topk_indices_slice
,
num_rows_slice
,
logits_slice
.
stride
(
0
),
logits_slice
.
stride
(
1
),
topk_tokens
,
)
else
:
op
.
top_k_per_row_prefill
(
logits_slice
,
ks_slice
,
ke_slice
,
topk_indices_slice
,
num_rows_slice
,
logits_slice
.
stride
(
0
),
logits_slice
.
stride
(
1
),
topk_tokens
,
)
if
has_decode
:
if
has_decode
:
decode_metadata
=
attn_metadata
.
decode
decode_metadata
=
attn_metadata
.
decode
...
@@ -242,7 +279,6 @@ def sparse_attn_indexer(
...
@@ -242,7 +279,6 @@ def sparse_attn_indexer(
max_model_len
,
max_model_len
,
)
)
num_rows
=
logits
.
shape
[
0
]
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[:
num_padded_tokens
,
:
topk_tokens
]
topk_indices
=
topk_indices_buffer
[:
num_padded_tokens
,
:
topk_tokens
]
...
@@ -423,6 +459,4 @@ class SparseAttnIndexer(CustomOp):
...
@@ -423,6 +459,4 @@ class SparseAttnIndexer(CustomOp):
self
.
max_model_len
,
self
.
max_model_len
,
self
.
max_total_seq_len
,
self
.
max_total_seq_len
,
self
.
topk_indices_buffer
,
self
.
topk_indices_buffer
,
)
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment