Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3b9aa746
Commit
3b9aa746
authored
Mar 11, 2026
by
zhangqha
Browse files
Merge branch 'v0.15.1-dev' into 'v0.15.1-dev-lxh'
# Conflicts: # vllm/model_executor/layers/fused_moe/fused_moe.py
parents
03a3c522
02a1e691
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
65 additions
and
58 deletions
+65
-58
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+2
-2
vllm/envs.py
vllm/envs.py
+1
-1
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+3
-2
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+2
-2
vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
...l_executor/layers/fused_moe/router/grouped_topk_router.py
+4
-2
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+3
-3
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+0
-26
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+1
-1
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+2
-0
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+2
-0
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+1
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+6
-2
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+4
-0
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+1
-0
vllm/v1/attention/ops/flashmla.py
vllm/v1/attention/ops/flashmla.py
+33
-16
No files found.
vllm/_aiter_ops.py
View file @
3b9aa746
...
@@ -921,12 +921,12 @@ class rocm_aiter_ops:
...
@@ -921,12 +921,12 @@ class rocm_aiter_ops:
return
cls
.
_AITER_ENABLED
and
cls
.
_RMSNORM_ENABLED
return
cls
.
_AITER_ENABLED
and
cls
.
_RMSNORM_ENABLED
@
classmethod
@
classmethod
@
if_aiter_supported
#
@if_aiter_supported
def
is_fused_moe_enabled
(
cls
)
->
bool
:
def
is_fused_moe_enabled
(
cls
)
->
bool
:
return
cls
.
_AITER_ENABLED
and
cls
.
_FMOE_ENABLED
return
cls
.
_AITER_ENABLED
and
cls
.
_FMOE_ENABLED
@
classmethod
@
classmethod
@
if_aiter_supported
#
@if_aiter_supported
def
is_fusion_moe_shared_experts_enabled
(
cls
)
->
bool
:
def
is_fusion_moe_shared_experts_enabled
(
cls
)
->
bool
:
return
cls
.
is_fused_moe_enabled
()
and
cls
.
_MOE_SHARED_EXPERTS_ENABLED
return
cls
.
is_fused_moe_enabled
()
and
cls
.
_MOE_SHARED_EXPERTS_ENABLED
...
...
vllm/envs.py
View file @
3b9aa746
...
@@ -1055,7 +1055,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1055,7 +1055,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Whether to use aiter triton fp4 bmm kernel
# Whether to use aiter triton fp4 bmm kernel
# By default is enabled.
# By default is enabled.
"VLLM_ROCM_USE_AITER_FP4BMM"
:
lambda
:
(
"VLLM_ROCM_USE_AITER_FP4BMM"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER_FP4BMM"
,
"
Tru
e"
).
lower
()
in
(
"true"
,
"1"
)
os
.
getenv
(
"VLLM_ROCM_USE_AITER_FP4BMM"
,
"
Fals
e"
).
lower
()
in
(
"true"
,
"1"
)
),
),
# Use AITER triton unified attention for V1 attention
# Use AITER triton unified attention for V1 attention
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION"
:
lambda
:
(
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION"
:
lambda
:
(
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
3b9aa746
...
@@ -215,6 +215,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -215,6 +215,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_and_maybe_dequant_weights
,
get_and_maybe_dequant_weights
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
get_gcn_arch_name
from
vllm.utils.flashinfer
import
has_nvidia_artifactory
from
vllm.utils.flashinfer
import
has_nvidia_artifactory
from
vllm.utils.math_utils
import
cdiv
,
round_down
from
vllm.utils.math_utils
import
cdiv
,
round_down
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
...
@@ -2115,7 +2116,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -2115,7 +2116,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
scale
=
layer
.
_k_scale
,
scale
=
layer
.
_k_scale
,
)
)
if
fp8_attention
:
if
fp8_attention
and
get_gcn_arch_name
()
==
"gfx938"
:
kv_cache
=
kv_cache
.
view
(
current_platform
.
fp8_dtype
())
kv_cache
=
kv_cache
.
view
(
current_platform
.
fp8_dtype
())
if
has_prefill
:
if
has_prefill
:
...
@@ -2185,7 +2186,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -2185,7 +2186,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Convert from (N, B, L) to (B, N, L)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
if
fp8_attention
:
if
fp8_attention
and
get_gcn_arch_name
()
==
"gfx938"
:
assert
decode_ql_nope
.
shape
[
0
]
==
decode_q_pe
.
shape
[
0
]
assert
decode_ql_nope
.
shape
[
0
]
==
decode_q_pe
.
shape
[
0
]
assert
decode_ql_nope
.
shape
[
1
]
==
decode_q_pe
.
shape
[
1
]
assert
decode_ql_nope
.
shape
[
1
]
==
decode_q_pe
.
shape
[
1
]
decode_q
=
self
.
_decode_concat_quant_fp8_op
(
decode_q
=
self
.
_decode_concat_quant_fp8_op
(
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
3b9aa746
...
@@ -1613,8 +1613,8 @@ def fused_experts(
...
@@ -1613,8 +1613,8 @@ def fused_experts(
quant_config
:
FusedMoEQuantConfig
|
None
=
None
,
quant_config
:
FusedMoEQuantConfig
|
None
=
None
,
use_nn_moe
:
bool
|
None
=
False
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
i_s
:
torch
.
Tensor
|
None
=
None
# TODO:wjl
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
quant_config
is
None
:
if
quant_config
is
None
:
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
...
...
vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
View file @
3b9aa746
...
@@ -335,8 +335,10 @@ class GroupedTopKRouter(BaseRouter):
...
@@ -335,8 +335,10 @@ class GroupedTopKRouter(BaseRouter):
rocm_aiter_grouped_topk
,
rocm_aiter_grouped_topk
,
num_fused_shared_experts
=
self
.
num_fused_shared_experts
,
num_fused_shared_experts
=
self
.
num_fused_shared_experts
,
)
)
enable_shared_experts_fusion
=
True
else
:
else
:
grouped_topk_impl
=
grouped_topk
grouped_topk_impl
=
grouped_topk
enable_shared_experts_fusion
=
False
if
self
.
use_fused_gate
:
if
self
.
use_fused_gate
:
if
envs
.
VLLM_USE_LIGHTOP
:
if
envs
.
VLLM_USE_LIGHTOP
:
...
@@ -347,7 +349,7 @@ class GroupedTopKRouter(BaseRouter):
...
@@ -347,7 +349,7 @@ class GroupedTopKRouter(BaseRouter):
self
.
num_expert_group
,
self
.
num_expert_group
,
self
.
topk_group
,
self
.
topk_group
,
self
.
top_k
,
self
.
top_k
,
0
,
self
.
num_fused_shared_experts
if
enable_shared_experts_fusion
else
0
,
self
.
routed_scaling_factor
,
self
.
routed_scaling_factor
,
)
)
else
:
else
:
...
@@ -358,7 +360,7 @@ class GroupedTopKRouter(BaseRouter):
...
@@ -358,7 +360,7 @@ class GroupedTopKRouter(BaseRouter):
self
.
topk_group
,
self
.
topk_group
,
self
.
top_k
,
self
.
top_k
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
n_share_experts_fusion
=
0
,
n_share_experts_fusion
=
(
self
.
num_fused_shared_experts
if
enable_shared_experts_fusion
else
0
)
,
)
)
else
:
else
:
topk_weights
,
topk_ids
=
grouped_topk_impl
(
topk_weights
,
topk_ids
=
grouped_topk_impl
(
...
...
vllm/model_executor/layers/layernorm.py
View file @
3b9aa746
...
@@ -335,7 +335,7 @@ class FusedRMSNormQuant(nn.Module):
...
@@ -335,7 +335,7 @@ class FusedRMSNormQuant(nn.Module):
quant_dtype
:
torch
.
dtype
=
torch
.
int8
,
quant_dtype
:
torch
.
dtype
=
torch
.
int8
,
update_input
:
Optional
[
bool
]
=
True
update_input
:
Optional
[
bool
]
=
True
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
i_q
,
i_s
=
torch
.
ops
.
vllm
.
fused_rmsquant
(
input
=
x
,
i_q
,
i_s
=
torch
.
ops
.
vllm
.
fused_rmsquant
_customer_impl
(
input
=
x
,
weight
=
self
.
weight
,
weight
=
self
.
weight
,
epsilon
=
self
.
variance_epsilon
,
epsilon
=
self
.
variance_epsilon
,
quant_dtype
=
quant_dtype
,
quant_dtype
=
quant_dtype
,
...
@@ -383,9 +383,9 @@ def fused_rmsquant_fake(
...
@@ -383,9 +383,9 @@ def fused_rmsquant_fake(
# customer_lib = Library("customer_", "FRAGMENT")
# customer_lib = Library("customer_", "FRAGMENT")
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"fused_rmsquant"
,
op_name
=
"fused_rmsquant
_customer_impl
"
,
op_func
=
fused_rmsquant_impl
,
op_func
=
fused_rmsquant_impl
,
mutates_args
=
[],
mutates_args
=
[
"input"
,
"residual"
],
fake_impl
=
fused_rmsquant_fake
,
fake_impl
=
fused_rmsquant_fake
,
)
)
...
...
vllm/model_executor/layers/linear.py
View file @
3b9aa746
...
@@ -711,32 +711,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -711,32 +711,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
disable_tp: If true, all weights matrix won't be sharded, this layer
disable_tp: If true, all weights matrix won't be sharded, this layer
will be treated as a "Replicated" MergedLinear.
will be treated as a "Replicated" MergedLinear.
"""
"""
def
forward
(
self
,
input_
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
Parameter
|
None
]:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
assert
self
.
quant_method
is
not
None
if
envs
.
USE_FUSED_RMS_QUANT
and
iqis
is
not
None
:
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
=
iqis
)
else
:
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
if
self
.
gather_output
and
self
.
tp_size
>
1
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
output
=
output_parallel
if
not
self
.
return_bias
:
return
output
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
return
output
,
output_bias
def
__init__
(
def
__init__
(
self
,
self
,
input_size
:
int
,
input_size
:
int
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
3b9aa746
...
@@ -1256,7 +1256,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1256,7 +1256,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_nn_moe
:
bool
|
None
=
False
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
3b9aa746
...
@@ -307,6 +307,8 @@ class SlimQuantW4A8Int8MoEMethod:
...
@@ -307,6 +307,8 @@ class SlimQuantW4A8Int8MoEMethod:
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
return
fused_experts
(
return
fused_experts
(
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
3b9aa746
...
@@ -224,6 +224,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -224,6 +224,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
3b9aa746
...
@@ -49,7 +49,7 @@ def sparse_attn_indexer(
...
@@ -49,7 +49,7 @@ def sparse_attn_indexer(
if
not
isinstance
(
attn_metadata
,
dict
):
if
not
isinstance
(
attn_metadata
,
dict
):
# Reserve workspace for indexer during profiling run
# Reserve workspace for indexer during profiling run
current_workspace_manager
().
get_simultaneous
(
current_workspace_manager
().
get_simultaneous
(
((
total_seq_lens
,
head_dim
),
fp8_dtype
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
torch
.
bfloat16
),
((
total_seq_lens
,
head_dim
),
fp8_dtype
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
k
.
dtype
,
),
((
total_seq_lens
,
4
),
torch
.
uint8
),
((
total_seq_lens
,
4
),
torch
.
uint8
),
)
)
return
sparse_attn_indexer_fake
(
return
sparse_attn_indexer_fake
(
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
3b9aa746
...
@@ -324,8 +324,12 @@ class DeepseekV2MoE(nn.Module):
...
@@ -324,8 +324,12 @@ class DeepseekV2MoE(nn.Module):
self
.
experts
=
SharedFusedMoE
(
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
shared_experts
=
self
.
shared_experts
,
gate
=
self
.
gate
,
gate
=
self
.
gate
,
num_experts
=
config
.
n_routed_experts
,
# num_experts=config.n_routed_experts,
top_k
=
config
.
num_experts_per_tok
,
# top_k=config.num_experts_per_tok,
num_experts
=
config
.
n_routed_experts
+
(
config
.
n_shared_experts
if
self
.
is_fusion_moe_shared_experts_enabled
else
0
),
top_k
=
config
.
num_experts_per_tok
+
(
config
.
n_shared_experts
if
self
.
is_fusion_moe_shared_experts_enabled
else
0
),
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
reduce_results
=
False
,
...
...
vllm/platforms/rocm.py
View file @
3b9aa746
...
@@ -121,6 +121,10 @@ def on_gfx9() -> bool:
...
@@ -121,6 +121,10 @@ def on_gfx9() -> bool:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
,
"gfx950"
,
"gfx928"
,
"gfx936"
,
"gfx938"
])
return
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
,
"gfx950"
,
"gfx928"
,
"gfx936"
,
"gfx938"
])
@
cache
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
@
cache
@
cache
def
on_gfx942
()
->
bool
:
def
on_gfx942
()
->
bool
:
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
3b9aa746
...
@@ -310,6 +310,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -310,6 +310,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
causal
=
True
,
causal
=
True
,
descale_q
=
layer
.
_q_scale
.
reshape
(
1
),
descale_q
=
layer
.
_q_scale
.
reshape
(
1
),
descale_k
=
layer
.
_k_scale
.
reshape
(
1
),
descale_k
=
layer
.
_k_scale
.
reshape
(
1
),
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
)
else
:
else
:
o
,
lse
=
flash_mla_with_kvcache
(
o
,
lse
=
flash_mla_with_kvcache
(
...
...
vllm/v1/attention/ops/flashmla.py
View file @
3b9aa746
...
@@ -6,7 +6,7 @@ import torch
...
@@ -6,7 +6,7 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
get_gcn_arch_name
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
():
...
@@ -136,7 +136,7 @@ def get_mla_metadata_dense_fp8(
...
@@ -136,7 +136,7 @@ def get_mla_metadata_dense_fp8(
cache_seqlens
,
cache_seqlens
,
num_q_tokens_per_head_k
,
num_q_tokens_per_head_k
,
num_heads_k
,
num_heads_k
,
16
,
#
16,
)
)
else
:
else
:
return
torch
.
ops
.
_flashmla_extension_C
.
get_mla_decoding_metadata_dense_fp8
(
return
torch
.
ops
.
_flashmla_extension_C
.
get_mla_decoding_metadata_dense_fp8
(
...
@@ -158,26 +158,43 @@ def flash_mla_with_kvcache_fp8(
...
@@ -158,26 +158,43 @@ def flash_mla_with_kvcache_fp8(
causal
:
bool
=
False
,
causal
:
bool
=
False
,
descale_q
:
torch
.
Tensor
|
None
=
None
,
descale_q
:
torch
.
Tensor
|
None
=
None
,
descale_k
:
torch
.
Tensor
|
None
=
None
,
descale_k
:
torch
.
Tensor
|
None
=
None
,
kv_cache_dtype
:
str
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
_is_flashmla_available
()[
0
]:
if
not
_is_flashmla_available
()[
0
]:
_raise_flashmla_unavailable
()
_raise_flashmla_unavailable
()
if
softmax_scale
is
None
:
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla_fp8
(
if
get_gcn_arch_name
()
==
"gfx938"
:
q
,
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla_fp8
(
k_cache
,
q
,
None
,
k_cache
,
head_dim_v
,
None
,
cache_seqlens
,
head_dim_v
,
block_table
,
cache_seqlens
,
softmax_scale
,
block_table
,
causal
,
softmax_scale
,
tile_scheduler_metadata
,
causal
,
num_splits
,
tile_scheduler_metadata
,
descale_q
,
num_splits
,
descale_k
,
descale_q
,
)
descale_k
,
)
else
:
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_quantization_mla
(
q
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
descale_k
,
kv_cache_dtype
,
)
else
:
else
:
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_extension_C
.
fwd_kvcache_mla_fp8
(
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_extension_C
.
fwd_kvcache_mla_fp8
(
q
,
q
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment