Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f441aca2
Commit
f441aca2
authored
Jan 05, 2026
by
zhuwenwen
Browse files
update mqa_logits and paged_mqa_logits
parent
cc7715fd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
97 additions
and
67 deletions
+97
-67
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+97
-67
No files found.
vllm/model_executor/models/deepseek_v2.py
View file @
f441aca2
...
@@ -83,6 +83,7 @@ from vllm import _custom_ops as ops
...
@@ -83,6 +83,7 @@ from vllm import _custom_ops as ops
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
import
lightop
from
lightop
import
op
,
gemmopt
from
lightop
import
op
,
gemmopt
else
:
else
:
from
vllm.utils.deep_gemm
import
fp8_mqa_logits
,
fp8_paged_mqa_logits
from
vllm.utils.deep_gemm
import
fp8_mqa_logits
,
fp8_paged_mqa_logits
...
@@ -601,33 +602,35 @@ def sparse_attn_indexer(
...
@@ -601,33 +602,35 @@ def sparse_attn_indexer(
has_prefill
=
attn_metadata
.
num_prefills
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
ops
.
indexer_k_quant_and_cache
(
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
k
,
ops
.
indexer_k_quant_and_cache
(
kv_cache
,
k
,
slot_mapping
,
kv_cache
,
quant_block_size
,
slot_mapping
,
scale_fmt
,
quant_block_size
,
)
scale_fmt
,
)
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
if
has_prefill
:
if
has_prefill
:
prefill_metadata
=
attn_metadata
.
prefill
prefill_metadata
=
attn_metadata
.
prefill
for
chunk
in
prefill_metadata
.
chunks
:
for
chunk
in
prefill_metadata
.
chunks
:
k_fp8
=
torch
.
empty
([
chunk
.
total_seq_lens
,
head_dim
],
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
k_fp8
=
torch
.
empty
([
chunk
.
total_seq_lens
,
head_dim
],
device
=
k
.
device
,
device
=
k
.
device
,
dtype
=
torch
.
float8_e4m3fn
)
dtype
=
torch
.
float8_e4m3fn
)
k_scale
=
torch
.
empty
([
chunk
.
total_seq_lens
,
1
],
k_scale
=
torch
.
empty
([
chunk
.
total_seq_lens
,
1
],
device
=
k
.
device
,
device
=
k
.
device
,
dtype
=
torch
.
float32
)
dtype
=
torch
.
float32
)
cp_gather_indexer_k_quant_cache
(
cp_gather_indexer_k_quant_cache
(
kv_cache
,
kv_cache
,
k_fp8
,
k_fp8
,
k_scale
,
k_scale
,
chunk
.
block_table
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
chunk
.
cu_seq_lens
,
chunk
.
num_reqs
,
chunk
.
num_reqs
,
)
)
if
not
current_platform
.
is_rocm
():
logits
=
fp8_mqa_logits
(
logits
=
fp8_mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
,
k_scale
),
(
k_fp8
,
k_scale
),
...
@@ -637,12 +640,18 @@ def sparse_attn_indexer(
...
@@ -637,12 +640,18 @@ def sparse_attn_indexer(
)
)
else
:
else
:
logits
=
op
.
mqa_logits
(
logits
=
op
.
mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
].
half
(),
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
.
half
(),
k_scale
),
k
,
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
weights
[
chunk
.
token_start
:
chunk
.
token_end
]
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
weights
[
chunk
.
token_start
:
chunk
.
token_end
].
to
(
torch
.
float32
),
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
chunk
.
cu_seqlen_ke
,
)
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
].
shape
[
0
],
k
.
shape
[
0
],
64
,
128
,
True
,
)
topk_indices
=
logits
.
topk
(
min
(
topk_tokens
,
logits
.
shape
[
-
1
]),
topk_indices
=
logits
.
topk
(
min
(
topk_tokens
,
logits
.
shape
[
-
1
]),
dim
=-
1
)[
1
]
dim
=-
1
)[
1
]
topk_indices
-=
chunk
.
cu_seqlen_ks
[:,
None
]
topk_indices
-=
chunk
.
cu_seqlen_ks
[:,
None
]
...
@@ -692,14 +701,15 @@ def sparse_attn_indexer(
...
@@ -692,14 +701,15 @@ def sparse_attn_indexer(
)
)
else
:
else
:
logits
=
gemmopt
.
paged_mqa_logits
(
logits
=
gemmopt
.
paged_mqa_logits
(
padded_q_fp8_decode_tokens
.
half
()
,
padded_q_fp8_decode_tokens
,
kv_cache
.
half
(
),
kv_cache
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
kv_cache
.
to
(
torch
.
bfloat16
),
weights
[:
num_padded_tokens
]
,
weights
[:
num_padded_tokens
]
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
weights
[:
num_padded_tokens
].
to
(
torch
.
float32
),
decode_metadata
.
seq_lens
,
decode_metadata
.
seq_lens
,
decode_metadata
.
block_table
,
decode_metadata
.
block_table
,
decode_metadata
.
schedule_metadata
,
decode_metadata
.
schedule_metadata
,
max_context_len
=
max_model_len
,
max_model_len
,
)
)
# padded query len
# padded query len
current_device
=
padded_q_fp8_decode_tokens
.
device
current_device
=
padded_q_fp8_decode_tokens
.
device
padded_num_tokens
=
batch_size
*
next_n
padded_num_tokens
=
batch_size
*
next_n
...
@@ -753,12 +763,13 @@ def sparse_attn_indexer_fake(
...
@@ -753,12 +763,13 @@ def sparse_attn_indexer_fake(
# profile run
# profile run
# NOTE(Chen): create the max possible flattened_kv. So that
# NOTE(Chen): create the max possible flattened_kv. So that
# profile_run can get correct memory usage.
# profile_run can get correct memory usage.
_flattened_kv
=
torch
.
empty
([
total_seq_lens
,
head_dim
+
4
],
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
device
=
k
.
device
,
_flattened_kv
=
torch
.
empty
([
total_seq_lens
,
head_dim
+
4
],
dtype
=
torch
.
uint8
)
device
=
k
.
device
,
_k_fp8
=
_flattened_kv
[...,
:
head_dim
].
view
(
dtype
=
torch
.
uint8
)
torch
.
float8_e4m3fn
).
contiguous
()
_k_fp8
=
_flattened_kv
[...,
:
head_dim
].
view
(
_k_scale
=
_flattened_kv
[...,
head_dim
:].
view
(
torch
.
float32
).
contiguous
()
torch
.
float8_e4m3fn
).
contiguous
()
_k_scale
=
_flattened_kv
[...,
head_dim
:].
view
(
torch
.
float32
).
contiguous
()
return
topk_indices_buffer
return
topk_indices_buffer
...
@@ -845,35 +856,54 @@ class Indexer(nn.Module):
...
@@ -845,35 +856,54 @@ class Indexer(nn.Module):
k
=
torch
.
cat
([
k_pe
.
squeeze
(
1
),
k_nope
],
dim
=-
1
)
k
=
torch
.
cat
([
k_pe
.
squeeze
(
1
),
k_nope
],
dim
=-
1
)
# we only quant q here since k quant is fused with cache insertion
# we only quant q here since k quant is fused with cache insertion
q
=
q
.
view
(
-
1
,
self
.
head_dim
)
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
q_fp8
,
q_scale
=
per_token_group_quant_fp8
(
q
,
q
=
q
.
view
(
-
1
,
self
.
head_dim
)
self
.
quant_block_size
,
q_fp8
,
q_scale
=
per_token_group_quant_fp8
(
q
,
column_major_scales
=
False
,
self
.
quant_block_size
,
use_ue8m0
=
self
.
scale_fmt
column_major_scales
=
False
,
is
not
None
)
use_ue8m0
=
self
.
scale_fmt
q_fp8
=
q_fp8
.
view
(
-
1
,
self
.
n_head
,
self
.
head_dim
)
is
not
None
)
q_scale
=
q_scale
.
view
(
-
1
,
self
.
n_head
,
1
)
q_fp8
=
q_fp8
.
view
(
-
1
,
self
.
n_head
,
self
.
head_dim
)
q_scale
=
q_scale
.
view
(
-
1
,
self
.
n_head
,
1
)
weights
,
_
=
self
.
weights_proj
(
hidden_states
)
weights
,
_
=
self
.
weights_proj
(
hidden_states
)
weights
=
weights
.
unsqueeze
(
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
-
1
)
*
q_scale
*
self
.
softmax_scale
*
self
.
n_head
**-
0.5
weights
=
weights
.
unsqueeze
(
weights
=
weights
.
squeeze
(
-
1
)
-
1
)
*
q_scale
*
self
.
softmax_scale
*
self
.
n_head
**-
0.5
weights
=
weights
.
squeeze
(
-
1
)
return
torch
.
ops
.
vllm
.
sparse_attn_indexer
(
hidden_states
,
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
self
.
k_cache
.
prefix
,
return
torch
.
ops
.
vllm
.
sparse_attn_indexer
(
self
.
k_cache
.
kv_cache
[
0
],
hidden_states
,
q_fp8
,
self
.
k_cache
.
prefix
,
k
,
self
.
k_cache
.
kv_cache
[
0
],
weights
,
q_fp8
,
self
.
quant_block_size
,
k
,
self
.
scale_fmt
,
weights
,
self
.
topk_tokens
,
self
.
quant_block_size
,
self
.
head_dim
,
self
.
scale_fmt
,
self
.
max_model_len
,
self
.
topk_tokens
,
self
.
max_total_seq_len
,
self
.
head_dim
,
self
.
topk_indices_buffer
,
self
.
max_model_len
,
)
self
.
max_total_seq_len
,
self
.
topk_indices_buffer
,
)
else
:
return
torch
.
ops
.
vllm
.
sparse_attn_indexer
(
hidden_states
,
self
.
k_cache
.
prefix
,
self
.
k_cache
.
kv_cache
[
0
],
q
,
k
,
weights
,
self
.
quant_block_size
,
self
.
scale_fmt
,
self
.
topk_tokens
,
self
.
head_dim
,
self
.
max_model_len
,
self
.
max_total_seq_len
,
self
.
topk_indices_buffer
,
)
class
DeepseekV2MLAAttention
(
nn
.
Module
):
class
DeepseekV2MLAAttention
(
nn
.
Module
):
...
@@ -1583,4 +1613,4 @@ def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config,
...
@@ -1583,4 +1613,4 @@ def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config,
for
i
in
range
(
config
.
num_nextn_predict_layers
):
for
i
in
range
(
config
.
num_nextn_predict_layers
):
if
weight_name
.
startswith
(
f
"model.layers.
{
layer_idx
+
i
}
."
):
if
weight_name
.
startswith
(
f
"model.layers.
{
layer_idx
+
i
}
."
):
return
layer_idx
+
i
return
layer_idx
+
i
return
None
return
None
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment