Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
578977bb
Unverified
Commit
578977bb
authored
Feb 10, 2026
by
Pavani Majety
Committed by
GitHub
Feb 10, 2026
Browse files
[SM100] Resubmit FMHA FP8 prefill for MLA (#31195)
Signed-off-by:
Pavani Majety
<
pmajety@nvidia.com
>
parent
9615575a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
145 additions
and
23 deletions
+145
-23
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+4
-3
vllm/config/attention.py
vllm/config/attention.py
+3
-0
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+138
-20
No files found.
tests/v1/attention/test_mla_backends.py
View file @
578977bb
...
@@ -27,7 +27,7 @@ from vllm.v1.attention.backend import CommonAttentionMetadata
...
@@ -27,7 +27,7 @@ from vllm.v1.attention.backend import CommonAttentionMetadata
from
vllm.v1.attention.backends.fa_utils
import
flash_attn_supports_mla
from
vllm.v1.attention.backends.fa_utils
import
flash_attn_supports_mla
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.ops.flashmla
import
is_flashmla_dense_supported
from
vllm.v1.attention.ops.flashmla
import
is_flashmla_dense_supported
from
vllm.v1.kv_cache_interface
import
Full
AttentionSpec
from
vllm.v1.kv_cache_interface
import
MLA
AttentionSpec
BACKENDS_TO_TEST
=
[
BACKENDS_TO_TEST
=
[
AttentionBackendEnum
.
CUTLASS_MLA
,
AttentionBackendEnum
.
CUTLASS_MLA
,
...
@@ -512,7 +512,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
...
@@ -512,7 +512,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def
run_attention_backend
(
def
run_attention_backend
(
backend
:
AttentionBackendEnum
,
backend
:
AttentionBackendEnum
,
kv_cache_spec
:
Full
AttentionSpec
,
kv_cache_spec
:
MLA
AttentionSpec
,
layer_names
:
list
[
str
],
layer_names
:
list
[
str
],
vllm_config
,
vllm_config
,
device
:
torch
.
device
,
device
:
torch
.
device
,
...
@@ -989,7 +989,7 @@ def test_backend_correctness(
...
@@ -989,7 +989,7 @@ def test_backend_correctness(
kv_cache
=
kv_cache_per_block_size
[
block_size
]
kv_cache
=
kv_cache_per_block_size
[
block_size
]
# Create kv_cache_spec with the correct block_size for this backend
# Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec
=
Full
AttentionSpec
(
backend_kv_cache_spec
=
MLA
AttentionSpec
(
block_size
=
block_size
,
block_size
=
block_size
,
num_kv_heads
=
vllm_config
.
model_config
.
get_num_kv_heads
(
num_kv_heads
=
vllm_config
.
model_config
.
get_num_kv_heads
(
vllm_config
.
parallel_config
vllm_config
.
parallel_config
...
@@ -997,6 +997,7 @@ def test_backend_correctness(
...
@@ -997,6 +997,7 @@ def test_backend_correctness(
head_size
=
vllm_config
.
model_config
.
get_head_size
(),
head_size
=
vllm_config
.
model_config
.
get_head_size
(),
dtype
=
vllm_config
.
model_config
.
dtype
,
dtype
=
vllm_config
.
model_config
.
dtype
,
sliding_window
=
vllm_config
.
model_config
.
get_sliding_window
(),
sliding_window
=
vllm_config
.
model_config
.
get_sliding_window
(),
cache_dtype_str
=
vllm_config
.
cache_config
.
cache_dtype
,
)
)
backend_output
=
run_attention_backend
(
backend_output
=
run_attention_backend
(
...
...
vllm/config/attention.py
View file @
578977bb
...
@@ -43,6 +43,9 @@ class AttentionConfig:
...
@@ -43,6 +43,9 @@ class AttentionConfig:
disable_flashinfer_q_quantization
:
bool
=
False
disable_flashinfer_q_quantization
:
bool
=
False
"""If set, when using fp8 kv, do not quantize Q to fp8."""
"""If set, when using fp8 kv, do not quantize Q to fp8."""
use_prefill_query_quantization
:
bool
=
False
"""If set, quantize query for attention in prefill."""
def
compute_hash
(
self
)
->
str
:
def
compute_hash
(
self
)
->
str
:
"""
"""
Provide a hash that uniquely identifies all the configs
Provide a hash that uniquely identifies all the configs
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
578977bb
...
@@ -1052,6 +1052,7 @@ class MLACommonPrefillMetadata:
...
@@ -1052,6 +1052,7 @@ class MLACommonPrefillMetadata:
query_seq_lens
:
torch
.
Tensor
|
None
=
None
query_seq_lens
:
torch
.
Tensor
|
None
=
None
workspace_buffer
:
torch
.
Tensor
|
None
=
None
workspace_buffer
:
torch
.
Tensor
|
None
=
None
q_data_type
:
torch
.
dtype
|
None
=
None
q_data_type
:
torch
.
dtype
|
None
=
None
output_dtype
:
torch
.
dtype
|
None
=
None
@
dataclass
@
dataclass
...
@@ -1145,6 +1146,7 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
...
@@ -1145,6 +1146,7 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
return
qk_nope_head_dim
==
128
and
qk_rope_head_dim
==
64
and
v_head_dim
==
128
return
qk_nope_head_dim
==
128
and
qk_rope_head_dim
==
64
and
v_head_dim
==
128
@
functools
.
cache
def
use_flashinfer_prefill
()
->
bool
:
def
use_flashinfer_prefill
()
->
bool
:
# For blackwell default to flashinfer prefill if it's available since
# For blackwell default to flashinfer prefill if it's available since
# it is faster than FA2.
# it is faster than FA2.
...
@@ -1162,6 +1164,7 @@ def use_flashinfer_prefill() -> bool:
...
@@ -1162,6 +1164,7 @@ def use_flashinfer_prefill() -> bool:
return
is_deepseek_r1_mla_compatible
(
vllm_config
)
return
is_deepseek_r1_mla_compatible
(
vllm_config
)
@
functools
.
cache
def
use_cudnn_prefill
()
->
bool
:
def
use_cudnn_prefill
()
->
bool
:
from
vllm.config
import
get_current_vllm_config
from
vllm.config
import
get_current_vllm_config
...
@@ -1174,6 +1177,7 @@ def use_cudnn_prefill() -> bool:
...
@@ -1174,6 +1177,7 @@ def use_cudnn_prefill() -> bool:
)
)
@
functools
.
cache
def
use_trtllm_ragged_deepseek_prefill
()
->
bool
:
def
use_trtllm_ragged_deepseek_prefill
()
->
bool
:
"""Check if TRT-LLM ragged DeepSeek prefill should be used."""
"""Check if TRT-LLM ragged DeepSeek prefill should be used."""
from
vllm.config
import
get_current_vllm_config
from
vllm.config
import
get_current_vllm_config
...
@@ -1210,6 +1214,27 @@ def get_mla_dims(model_config: ModelConfig) -> MLADims:
...
@@ -1210,6 +1214,27 @@ def get_mla_dims(model_config: ModelConfig) -> MLADims:
)
)
@
functools
.
cache
def
backend_supports_prefill_query_quantization
()
->
bool
:
"""Check if the selected MLA backend supports prefill query quantization.
Currently supported backends:
- FlashInfer prefill
- TRT-LLM ragged DeepSeek prefill
Not supported:
- cuDNN Prefill
- FlashAttention
- Non-GB200 devices (FP8 prefill requires device capability 100)
"""
# FP8 prefill query quantization requires GB200 (device capability 100)
# for the necessary FP8 kernels at the moment.
if
not
current_platform
.
is_device_capability_family
(
100
):
return
False
return
use_flashinfer_prefill
()
or
use_trtllm_ragged_deepseek_prefill
()
class
MLACommonMetadataBuilder
(
AttentionMetadataBuilder
[
M
]):
class
MLACommonMetadataBuilder
(
AttentionMetadataBuilder
[
M
]):
"""
"""
NOTE: Please read the comment at the top of the file before trying to
NOTE: Please read the comment at the top of the file before trying to
...
@@ -1262,6 +1287,40 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -1262,6 +1287,40 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
return
chunked_prefill_workspace_size
return
chunked_prefill_workspace_size
@
staticmethod
def
determine_prefill_query_data_type
(
vllm_config
:
VllmConfig
,
model_dtype
:
torch
.
dtype
,
)
->
torch
.
dtype
:
"""
Determine the query data type for prefill queries.
Return FP8 dtype if cache is FP8 and prefill query quantization
is enabled, else model dtype.
"""
use_fp8
=
(
vllm_config
.
cache_config
.
cache_dtype
.
startswith
(
"fp8"
)
and
vllm_config
.
attention_config
.
use_prefill_query_quantization
and
backend_supports_prefill_query_quantization
()
)
if
use_fp8
:
fp8_dtype
=
current_platform
.
fp8_dtype
()
logger
.
info_once
(
"FP8 prefill attention enabled: query data type is FP8"
,
scope
=
"local"
)
return
fp8_dtype
elif
vllm_config
.
attention_config
.
use_prefill_query_quantization
:
logger
.
info_once
(
"Unable to perform FP8 prefill attention when"
" use_prefill_query_quantization is enabled. Please"
" ensure that --kv-cache-dtype is set to fp8 and your prefill"
" backend is compatible with FP8 attention."
,
scope
=
"local"
,
)
return
model_dtype
return
model_dtype
def
__init__
(
def
__init__
(
self
,
self
,
kv_cache_spec
:
AttentionSpec
,
kv_cache_spec
:
AttentionSpec
,
...
@@ -1285,6 +1344,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -1285,6 +1344,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self
.
num_heads
=
self
.
model_config
.
get_num_attention_heads
(
parallel_config
)
self
.
num_heads
=
self
.
model_config
.
get_num_attention_heads
(
parallel_config
)
self
.
mla_dims
=
get_mla_dims
(
self
.
model_config
)
self
.
mla_dims
=
get_mla_dims
(
self
.
model_config
)
self
.
aot_schedule
=
current_platform
.
is_cuda
()
self
.
aot_schedule
=
current_platform
.
is_cuda
()
self
.
kv_cache_spec
=
kv_cache_spec
self
.
q_data_type
=
self
.
determine_prefill_query_data_type
(
vllm_config
,
self
.
model_config
.
dtype
)
try
:
try
:
self
.
dcp_world_size
=
get_dcp_group
().
world_size
self
.
dcp_world_size
=
get_dcp_group
().
world_size
self
.
dcp_rank
=
get_dcp_group
().
rank_in_group
self
.
dcp_rank
=
get_dcp_group
().
rank_in_group
...
@@ -1325,7 +1390,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -1325,7 +1390,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self
.
chunked_prefill_workspace_size
,
self
.
chunked_prefill_workspace_size
,
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_head_size
(),
),
),
dtype
=
self
.
model_config
.
d
type
,
dtype
=
self
.
q_data_
type
,
device
=
device
,
device
=
device
,
)
)
...
@@ -1435,7 +1500,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -1435,7 +1500,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale
=
self
.
_global_hyperparameters
.
sm_scale
,
sm_scale
=
self
.
_global_hyperparameters
.
sm_scale
,
window_left
=
self
.
_global_hyperparameters
.
window_left
,
window_left
=
self
.
_global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
_global_hyperparameters
.
logits_soft_cap
,
logits_soft_cap
=
self
.
_global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
self
.
model_config
.
dtype
,
q_data_type
=
self
.
q_data_type
,
o_data_type
=
prefill
.
output_dtype
,
)
)
# Prepare context prefills
# Prepare context prefills
...
@@ -1454,7 +1520,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -1454,7 +1520,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale
=
self
.
_global_hyperparameters
.
sm_scale
,
sm_scale
=
self
.
_global_hyperparameters
.
sm_scale
,
window_left
=
self
.
_global_hyperparameters
.
window_left
,
window_left
=
self
.
_global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
_global_hyperparameters
.
logits_soft_cap
,
logits_soft_cap
=
self
.
_global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
self
.
model_config
.
dtype
,
q_data_type
=
self
.
q_data_type
,
o_data_type
=
prefill
.
output_dtype
,
)
)
prefill
.
prefill_main
=
self
.
_fi_prefill_main
prefill
.
prefill_main
=
self
.
_fi_prefill_main
...
@@ -1709,6 +1776,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -1709,6 +1776,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc
=
prefill_query_start_loc
,
query_start_loc
=
prefill_query_start_loc
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
chunked_context
=
chunked_context_metadata
,
chunked_context
=
chunked_context_metadata
,
output_dtype
=
self
.
model_config
.
dtype
,
q_data_type
=
self
.
q_data_type
,
)
)
if
self
.
_use_cudnn_prefill
:
if
self
.
_use_cudnn_prefill
:
...
@@ -1894,7 +1963,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1894,7 +1963,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self
.
kv_b_proj
=
kv_b_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
indexer
=
indexer
self
.
indexer
=
indexer
self
.
q_pad_num_heads
=
q_pad_num_heads
self
.
q_pad_num_heads
=
q_pad_num_heads
self
.
supports_quant_query_input
=
True
self
.
supports_quant_query_input
=
True
# Use flashinfer's optimized concat_mla_k kernel when available.
# Use flashinfer's optimized concat_mla_k kernel when available.
...
@@ -2129,6 +2197,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -2129,6 +2197,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
assert
prefill
.
query_seq_lens
is
not
None
assert
prefill
.
query_seq_lens
is
not
None
assert
prefill
.
workspace_buffer
is
not
None
assert
prefill
.
workspace_buffer
is
not
None
# allocate BF16 / FP16 output tensor for TRT-LLM ragged attention
out
=
torch
.
empty
(
q
.
shape
[
0
],
q
.
shape
[
1
],
v
.
shape
[
2
],
device
=
q
.
device
,
dtype
=
prefill
.
output_dtype
,
)
ret
=
trtllm_ragged_attention_deepseek
(
ret
=
trtllm_ragged_attention_deepseek
(
query
=
q
,
query
=
q
,
...
@@ -2148,6 +2224,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -2148,6 +2224,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
enable_pdl
=
False
,
enable_pdl
=
False
,
is_causal
=
True
,
is_causal
=
True
,
return_lse
=
return_softmax_lse
,
return_lse
=
return_softmax_lse
,
out
=
out
,
)
)
if
isinstance
(
ret
,
tuple
):
if
isinstance
(
ret
,
tuple
):
...
@@ -2170,7 +2247,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -2170,7 +2247,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q
.
shape
[
1
],
q
.
shape
[
1
],
v
.
shape
[
2
],
v
.
shape
[
2
],
device
=
q
.
device
,
device
=
q
.
device
,
dtype
=
q
.
dtype
,
dtype
=
prefill
.
output_
dtype
,
)
)
prefill
.
workspace_buffer
.
fill_
(
0
)
prefill
.
workspace_buffer
.
fill_
(
0
)
...
@@ -2240,29 +2317,59 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -2240,29 +2317,59 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
prefill_metadata
=
attn_metadata
.
prefill
prefill_metadata
=
attn_metadata
.
prefill
assert
prefill_metadata
.
chunked_context
is
not
None
assert
prefill_metadata
.
chunked_context
is
not
None
use_fp8_prefill
=
prefill_metadata
.
q_data_type
==
current_platform
.
fp8_dtype
()
output
=
None
output
=
None
iters
=
len
(
prefill_metadata
.
chunked_context
.
seq_tot
)
iters
=
len
(
prefill_metadata
.
chunked_context
.
seq_tot
)
workspace
=
prefill_metadata
.
chunked_context
.
workspace
workspace
=
prefill_metadata
.
chunked_context
.
workspace
if
use_fp8_prefill
:
q
=
q
.
to
(
prefill_metadata
.
q_data_type
)
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
toks
=
prefill_metadata
.
chunked_context
.
seq_tot
[
i
]
toks
=
prefill_metadata
.
chunked_context
.
seq_tot
[
i
]
ops
.
gather_and_maybe_dequant_cache
(
if
not
use_fp8_prefill
:
src_cache
=
kv_c_and_k_pe_cache
,
ops
.
gather_and_maybe_dequant_cache
(
dst
=
workspace
,
src_cache
=
kv_c_and_k_pe_cache
,
block_table
=
prefill_metadata
.
block_table
,
dst
=
workspace
,
cu_seq_lens
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
block_table
=
prefill_metadata
.
block_table
,
token_to_seq
=
prefill_metadata
.
chunked_context
.
token_to_seq
[
i
],
cu_seq_lens
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
num_tokens
=
prefill_metadata
.
chunked_context
.
chunk_total_token
[
i
],
token_to_seq
=
prefill_metadata
.
chunked_context
.
token_to_seq
[
i
],
kv_cache_dtype
=
self
.
kv_cache_dtype
,
num_tokens
=
prefill_metadata
.
chunked_context
.
chunk_total_token
[
i
],
scale
=
k_scale
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
seq_starts
=
prefill_metadata
.
chunked_context
.
starts
[
i
],
scale
=
k_scale
,
)
seq_starts
=
prefill_metadata
.
chunked_context
.
starts
[
i
],
)
else
:
# FP8 path: gather cache without dequantization
ops
.
cp_gather_cache
(
src_cache
=
kv_c_and_k_pe_cache
,
dst
=
workspace
,
block_table
=
prefill_metadata
.
block_table
,
cu_seq_lens
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
batch_size
=
attn_metadata
.
num_prefills
,
seq_starts
=
prefill_metadata
.
chunked_context
.
starts
[
i
],
)
# Extract kv_c_normed from workspace
kv_c_normed
=
workspace
[:
toks
][...,
:
self
.
kv_lora_rank
]
kv_c_normed
=
workspace
[:
toks
][...,
:
self
.
kv_lora_rank
]
k_pe
=
workspace
[:
toks
][...,
self
.
kv_lora_rank
:].
unsqueeze
(
1
)
# When FP8 weights are used without FP8 prefill, kv_b_proj expects
# model dtype input and will quantize internally.
if
(
use_fp8_prefill
or
self
.
kv_b_proj
.
weight
.
dtype
!=
current_platform
.
fp8_dtype
()
):
kv_c_normed
=
kv_c_normed
.
to
(
self
.
kv_b_proj
.
weight
.
dtype
)
k_pe
=
workspace
[:
toks
][...,
self
.
kv_lora_rank
:].
unsqueeze
(
1
)
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
)
# To Do: Use epilogue of kv_b_proj to generate fp8 kv_nope.
if
use_fp8_prefill
:
kv_nope
=
kv_nope
.
to
(
prefill_metadata
.
q_data_type
)
k_pe
=
k_pe
.
to
(
prefill_metadata
.
q_data_type
)
k_nope
,
v
=
kv_nope
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k_nope
,
v
=
kv_nope
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
self
.
_concat_k_nope_k_pe
(
k_nope
,
k_pe
)
k
=
self
.
_concat_k_nope_k_pe
(
k_nope
,
k_pe
)
...
@@ -2412,16 +2519,27 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -2412,16 +2519,27 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
assert
attn_metadata
.
prefill
is
not
None
assert
attn_metadata
.
prefill
is
not
None
assert
self
.
dcp_world_size
!=
-
1
assert
self
.
dcp_world_size
!=
-
1
has_context
=
attn_metadata
.
prefill
.
chunked_context
is
not
None
prefill_metadata
=
attn_metadata
.
prefill
use_fp8_prefill
=
prefill_metadata
.
q_data_type
==
current_platform
.
fp8_dtype
()
# Convert q to FP8 if FP8 prefill attention is enabled
if
use_fp8_prefill
:
q
=
q
.
to
(
prefill_metadata
.
q_data_type
)
has_context
=
prefill_metadata
.
chunked_context
is
not
None
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
)
k_nope
,
v
=
kv_nope
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k_nope
,
v
=
kv_nope
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
self
.
_concat_k_nope_k_pe
(
k_nope
,
k_pe
)
k
=
self
.
_concat_k_nope_k_pe
(
k_nope
,
k_pe
)
if
use_fp8_prefill
:
k
=
k
.
to
(
prefill_metadata
.
q_data_type
)
v
=
v
.
to
(
prefill_metadata
.
q_data_type
)
output_prefill
=
self
.
_run_prefill_new_tokens
(
output_prefill
=
self
.
_run_prefill_new_tokens
(
prefill
=
attn
_metadata
.
prefill
,
prefill
=
prefill
_metadata
,
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
v
=
v
,
v
=
v
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment