Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
25ec6a34
Commit
25ec6a34
authored
Jan 06, 2026
by
zhuwenwen
Browse files
update mqa_logits and paged_mqa_logits
parent
8a4a6fd8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
36 deletions
+45
-36
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+45
-36
No files found.
vllm/model_executor/models/deepseek_v2.py
View file @
25ec6a34
...
...
@@ -674,13 +674,14 @@ def sparse_attn_indexer(
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
ops
.
indexer_k_quant_and_cache
(
k
,
kv_cache
,
slot_mapping
,
quant_block_size
,
scale_fmt
,
)
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
ops
.
indexer_k_quant_and_cache
(
k
,
kv_cache
,
slot_mapping
,
quant_block_size
,
scale_fmt
,
)
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
if
has_prefill
:
...
...
@@ -694,15 +695,16 @@ def sparse_attn_indexer(
)
for
chunk
in
prefill_metadata
.
chunks
:
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
ops
.
cp_gather_indexer_k_quant_cache
(
kv_cache
,
k_fp8
,
k_scale
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
)
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
ops
.
cp_gather_indexer_k_quant_cache
(
kv_cache
,
k_fp8
,
k_scale
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
)
fp8_mqa_logits_func
=
fp8_mqa_logits
if
current_platform
.
is_rocm
():
# from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits
...
...
@@ -712,10 +714,15 @@ def sparse_attn_indexer(
logits
=
fp8_mqa_logits_func
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
)),
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
))
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
k_fp8
,
weights
[
chunk
.
token_start
:
chunk
.
token_end
]
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
weights
[
chunk
.
token_start
:
chunk
.
token_end
].
to
(
torch
.
float32
)
,
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
].
shape
[
0
],
k
.
shape
[
0
],
64
,
128
,
True
,
)
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[
...
...
@@ -766,11 +773,11 @@ def sparse_attn_indexer(
logits
=
fp8_paged_mqa_logits_func
(
padded_q_fp8_decode_tokens
,
kv_cache
,
weights
[:
num_padded_tokens
],
weights
[:
num_padded_tokens
]
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
weights
[:
num_padded_tokens
].
to
(
torch
.
float32
)
,
decode_metadata
.
seq_lens
,
decode_metadata
.
block_table
,
decode_metadata
.
schedule_metadata
,
max_model_len
=
max_model_len
,
max_model_len
,
)
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[:
num_decode_tokens
,
:
topk_tokens
]
...
...
@@ -876,8 +883,8 @@ class Indexer(nn.Module):
# where we store value in fp8 and scale in fp32
# per self.quant_block_size element
self
.
k_cache
=
DeepseekV32IndexerCache
(
head_dim
=
self
.
head_dim
+
self
.
head_dim
//
self
.
quant_block_size
*
4
,
dtype
=
torch
.
uint8
,
head_dim
=
self
.
head_dim
+
self
.
head_dim
//
self
.
quant_block_size
*
4
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
self
.
head_dim
,
dtype
=
torch
.
uint8
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
torch
.
bfloat16
,
prefix
=
f
"
{
prefix
}
.k_cache"
,
cache_config
=
cache_config
,
)
...
...
@@ -907,27 +914,29 @@ class Indexer(nn.Module):
k
=
torch
.
cat
([
k_pe
.
squeeze
((
0
,
2
)),
k_nope
],
dim
=-
1
)
# we only quant q here since k quant is fused with cache insertion
q
=
q
.
view
(
-
1
,
self
.
head_dim
)
q_fp8
,
q_scale
=
per_token_group_quant_fp8
(
q
,
self
.
quant_block_size
,
column_major_scales
=
False
,
use_ue8m0
=
self
.
scale_fmt
is
not
None
,
)
q_fp8
=
q_fp8
.
view
(
-
1
,
self
.
n_head
,
self
.
head_dim
)
q_scale
=
q_scale
.
view
(
-
1
,
self
.
n_head
,
1
)
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
q
=
q
.
view
(
-
1
,
self
.
head_dim
)
q_fp8
,
q_scale
=
per_token_group_quant_fp8
(
q
,
self
.
quant_block_size
,
column_major_scales
=
False
,
use_ue8m0
=
self
.
scale_fmt
is
not
None
,
)
q_fp8
=
q_fp8
.
view
(
-
1
,
self
.
n_head
,
self
.
head_dim
)
q_scale
=
q_scale
.
view
(
-
1
,
self
.
n_head
,
1
)
weights
,
_
=
self
.
weights_proj
(
hidden_states
)
weights
=
(
weights
.
unsqueeze
(
-
1
)
*
q_scale
*
self
.
softmax_scale
*
self
.
n_head
**-
0.5
)
weights
=
weights
.
squeeze
(
-
1
)
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
weights
=
(
weights
.
unsqueeze
(
-
1
)
*
q_scale
*
self
.
softmax_scale
*
self
.
n_head
**-
0.5
)
weights
=
weights
.
squeeze
(
-
1
)
return
torch
.
ops
.
vllm
.
sparse_attn_indexer
(
hidden_states
,
self
.
k_cache
.
prefix
,
self
.
k_cache
.
kv_cache
[
0
],
q_fp8
,
q_fp8
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
q
,
k
,
weights
,
self
.
quant_block_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment