Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
347fc09c
Commit
347fc09c
authored
Dec 17, 2025
by
zhuwenwen
Browse files
Merge branch 'v0.9.2-dev-nmz' into v0.9.2-dev
parents
ffcc47b7
3e191138
Changes
8
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
414 additions
and
219 deletions
+414
-219
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+32
-15
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+1
-2
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+74
-0
vllm/envs.py
vllm/envs.py
+5
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+4
-4
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+239
-172
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+4
-3
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+55
-23
No files found.
vllm/attention/backends/flashmla.py
View file @
347fc09c
...
@@ -16,7 +16,10 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
...
@@ -16,7 +16,10 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonState
)
MLACommonState
)
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
get_mla_metadata
,
get_mla_metadata
,
flash_mla_with_kvcache_fp8
,
get_mla_decoding_metadata_dense_fp8
,
is_flashmla_supported
)
is_flashmla_supported
)
from
vllm
import
envs
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
...
@@ -93,7 +96,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
...
@@ -93,7 +96,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self
.
num_q_heads
,
self
.
num_q_heads
,
1
,
# MQA for the decode path
1
,
# MQA for the decode path
)
)
return
m
return
m
...
@@ -222,6 +224,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -222,6 +224,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLAMetadata
,
attn_metadata
:
FlashMLAMetadata
,
q_scale
=
None
,
k_scale
=
None
,
k_scale
=
None
,
kv_cache_dtype
=
"auto"
,
kv_cache_dtype
=
"auto"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -233,6 +236,21 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -233,6 +236,21 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
k_cache
=
kv_c_and_k_pe_cache
.
view
(
torch
.
float8_e4m3fn
).
unsqueeze
(
-
2
),
# Add head dim of 1
block_table
=
decode_meta
.
block_tables
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
decode_meta
.
decode_tile_scheduler_metadata
,
num_splits
=
decode_meta
.
decode_num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
else
:
o
,
_
=
flash_mla_with_kvcache
(
o
,
_
=
flash_mla_with_kvcache
(
q
=
q
,
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
...
@@ -246,5 +264,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -246,5 +264,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
k_scale
=
k_scale
,
k_scale
=
k_scale
,
kv_cache_dtype
=
kv_cache_dtype
,
kv_cache_dtype
=
kv_cache_dtype
,
)
)
return
self
.
_v_up_proj
(
o
)
return
self
.
_v_up_proj
(
o
)
vllm/attention/backends/mla/common.py
View file @
347fc09c
...
@@ -1404,6 +1404,5 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1404,6 +1404,5 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
output
[
num_prefill_tokens
:]
=
self
.
_forward_decode
(
output
[
num_prefill_tokens
:]
=
self
.
_forward_decode
(
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
)
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
.
_q_scale
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
)
return
output
return
output
\ No newline at end of file
vllm/attention/ops/flashmla.py
View file @
347fc09c
...
@@ -69,6 +69,27 @@ def get_mla_metadata(
...
@@ -69,6 +69,27 @@ def get_mla_metadata(
num_heads_k
)
num_heads_k
)
def
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
:
torch
.
Tensor
,
num_heads_per_head_k
:
int
,
num_heads_k
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Return:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return
flash_mla_cuda
.
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
,
num_heads_per_head_k
,
num_heads_k
)
def
flash_mla_with_kvcache
(
def
flash_mla_with_kvcache
(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
...
@@ -199,6 +220,59 @@ def flash_mla_with_kvcache_q_nope_pe(
...
@@ -199,6 +220,59 @@ def flash_mla_with_kvcache_q_nope_pe(
return
out
,
softmax_lse
return
out
,
softmax_lse
def
flash_mla_with_kvcache_fp8
(
q
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
head_dim_v
:
int
,
tile_scheduler_metadata
:
torch
.
Tensor
,
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
descale_q
:
Optional
[
torch
.
Tensor
]
=
None
,
descale_k
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head_dim of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
torch.int32, return by get_mla_decoding_metadata_dense_fp8.
num_splits: (batch_size + 1), torch.int32, return by get_mla_decoding_metadata_dense_fp8.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla_fp8
(
q
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
descale_q
,
descale_k
,
)
return
out
,
softmax_lse
#
#
# TODO: Add fake functions
# TODO: Add fake functions
#
#
...
...
vllm/envs.py
View file @
347fc09c
...
@@ -146,6 +146,7 @@ if TYPE_CHECKING:
...
@@ -146,6 +146,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_OPT_MLA
:
bool
=
False
VLLM_USE_TRITON_OPT_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA_FP8
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_TC_PAGED_ATTN
:
bool
=
False
VLLM_USE_TC_PAGED_ATTN
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
...
@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_MLA"
:
"VLLM_USE_FLASH_MLA"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_MLA"
,
"1"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_MLA"
,
"1"
))),
# If set, vLLM will use FLASH MLA fp8 attention optimizations.
"VLLM_USE_FLASH_MLA_FP8"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_MLA_FP8"
,
"0"
))),
# flag to control vllm to use optimized kernels
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP"
:
"VLLM_USE_OPT_OP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_OP"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_OP"
,
"True"
).
lower
()
in
...
...
vllm/model_executor/model_loader/utils.py
View file @
347fc09c
...
@@ -255,8 +255,8 @@ def get_model_architecture(
...
@@ -255,8 +255,8 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_REJECT_SAMPLE_OPT"
):
if
not
envs
.
is_set
(
"VLLM_REJECT_SAMPLE_OPT"
):
os
.
environ
[
'VLLM_REJECT_SAMPLE_OPT'
]
=
'1'
os
.
environ
[
'VLLM_REJECT_SAMPLE_OPT'
]
=
'1'
#
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
#
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
os
.
environ
[
'VLLM_SCHED_ENABLE_MINIMAL_INJECTION'
]
=
'1'
os
.
environ
[
'VLLM_SCHED_ENABLE_MINIMAL_INJECTION'
]
=
'1'
if
model_config
.
quantization
in
{
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
,
"slimquant_compressed_tensors_marlin"
,
"compressed-tensors"
}:
if
model_config
.
quantization
in
{
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
,
"slimquant_compressed_tensors_marlin"
,
"compressed-tensors"
}:
...
@@ -300,8 +300,8 @@ def get_model_architecture(
...
@@ -300,8 +300,8 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
os
.
environ
[
'VLLM_USE_CAT_MLA'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_REJECT_SAMPLE_OPT"
):
if
not
envs
.
is_set
(
"VLLM_REJECT_SAMPLE_OPT"
):
os
.
environ
[
'VLLM_REJECT_SAMPLE_OPT'
]
=
'1'
os
.
environ
[
'VLLM_REJECT_SAMPLE_OPT'
]
=
'1'
#
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
if
not
envs
.
is_set
(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"
):
#
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
os
.
environ
[
'VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
if
not
envs
.
is_set
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
):
os
.
environ
[
'VLLM_SCHED_ENABLE_MINIMAL_INJECTION'
]
=
'1'
os
.
environ
[
'VLLM_SCHED_ENABLE_MINIMAL_INJECTION'
]
=
'1'
if
model_config
.
quantization
in
{
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
,
"slimquant_compressed_tensors_marlin"
,
"compressed-tensors"
}:
if
model_config
.
quantization
in
{
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
,
"slimquant_compressed_tensors_marlin"
,
"compressed-tensors"
}:
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
347fc09c
This diff is collapsed.
Click to expand it.
vllm/v1/attention/backends/mla/common.py
View file @
347fc09c
...
@@ -1095,6 +1095,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1095,6 +1095,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata
:
M
,
attn_metadata
:
M
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
query_nope
:
Optional
[
torch
.
Tensor
]
=
None
,
num_local_heads
:
Optional
[
int
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
q_ori
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
key_normed
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
positions
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1154,7 +1156,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1154,7 +1156,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale
=
layer
.
_k_scale
,
scale
=
layer
.
_k_scale
,
)
)
else
:
else
:
from
lightop
import
fused_rms_norm_rope_contiguous
if
self
.
kv_cache_dtype
==
"auto"
:
if
self
.
kv_cache_dtype
==
"auto"
:
if
q
.
dtype
==
torch
.
float16
:
if
q
.
dtype
==
torch
.
float16
:
kv_cache_dtype_str
=
"fp16"
kv_cache_dtype_str
=
"fp16"
...
@@ -1162,7 +1163,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1162,7 +1163,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str
=
"bf16"
kv_cache_dtype_str
=
"bf16"
else
:
else
:
kv_cache_dtype_str
=
self
.
kv_cache_dtype
kv_cache_dtype_str
=
self
.
kv_cache_dtype
from
lightop
import
fused_rms_norm_rope_contiguous
fused_rms_norm_rope_contiguous
(
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
positions
[:
num_actual_toks
,
...],
q
,
q
,
...
@@ -1199,6 +1200,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1199,6 +1200,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
output
[:
num_decode_tokens
]
=
self
.
_forward_decode
(
output
[:
num_decode_tokens
]
=
self
.
_forward_decode
(
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
)
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
.
_q_scale
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
)
return
output_padded
return
output_padded
\ No newline at end of file
vllm/v1/attention/backends/mla/flashmla.py
View file @
347fc09c
...
@@ -11,6 +11,8 @@ from vllm.attention.backends.abstract import (AttentionType,
...
@@ -11,6 +11,8 @@ from vllm.attention.backends.abstract import (AttentionType,
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
flash_mla_with_kvcache_q_nope_pe
,
flash_mla_with_kvcache_q_nope_pe
,
get_mla_metadata
,
get_mla_metadata
,
flash_mla_with_kvcache_fp8
,
get_mla_decoding_metadata_dense_fp8
,
is_flashmla_supported
)
is_flashmla_supported
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
...
@@ -162,12 +164,42 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -162,12 +164,42 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLAMetadata
,
attn_metadata
:
FlashMLAMetadata
,
q_scale
=
None
,
k_scale
=
None
,
k_scale
=
None
,
kv_cache_dtype
=
"auto"
,
kv_cache_dtype
=
"auto"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
q
=
concat_helper_decode
(
q_nope
,
q_pe
,
dim
=
2
)
\
.
unsqueeze
(
1
)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
k_cache
=
kv_c_and_k_pe_cache
.
view
(
torch
.
float8_e4m3fn
).
unsqueeze
(
-
2
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
else
:
if
not
envs
.
VLLM_USE_CAT_MLA
or
kv_cache_dtype
==
"fp8_e4m3"
:
if
not
envs
.
VLLM_USE_CAT_MLA
or
kv_cache_dtype
==
"fp8_e4m3"
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
if
q_nope
.
shape
[
0
]
<
1024
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment