Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
49a12622
Unverified
Commit
49a12622
authored
Jan 22, 2026
by
Alex Sun
Committed by
GitHub
Jan 22, 2026
Browse files
[AMD][ROCm] MoRI EP: a high-performance all2all backend (#28664)
Signed-off-by:
Alex Sun
<
alex.s@amd.com
>
parent
2b8a38b6
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
397 additions
and
9 deletions
+397
-9
tests/kernels/moe/modular_kernel_tools/cli_args.py
tests/kernels/moe/modular_kernel_tools/cli_args.py
+1
-1
tests/kernels/moe/modular_kernel_tools/common.py
tests/kernels/moe/modular_kernel_tools/common.py
+19
-1
tests/kernels/moe/modular_kernel_tools/mk_objects.py
tests/kernels/moe/modular_kernel_tools/mk_objects.py
+48
-1
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+12
-0
vllm/config/parallel.py
vllm/config/parallel.py
+3
-0
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+94
-1
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+4
-0
vllm/envs.py
vllm/envs.py
+3
-0
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+4
-0
vllm/model_executor/layers/fused_moe/all2all_utils.py
vllm/model_executor/layers/fused_moe/all2all_utils.py
+34
-1
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+8
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+13
-0
vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
.../model_executor/layers/fused_moe/mori_prepare_finalize.py
+121
-0
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+17
-4
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+6
-0
vllm/utils/import_utils.py
vllm/utils/import_utils.py
+10
-0
No files found.
tests/kernels/moe/modular_kernel_tools/cli_args.py
View file @
49a12622
...
@@ -141,7 +141,7 @@ def make_config(args: argparse.Namespace) -> Config:
...
@@ -141,7 +141,7 @@ def make_config(args: argparse.Namespace) -> Config:
quant_config
=
None
quant_config
=
None
if
args
.
quant_dtype
is
not
None
:
if
args
.
quant_dtype
is
not
None
:
quant_config
=
FusedMoEQuantConfig
(
quant_config
=
FusedMoEQuantConfig
.
make
(
quant_dtype
=
args
.
quant_dtype
,
quant_dtype
=
args
.
quant_dtype
,
per_act_token_quant
=
args
.
per_token_quantized_activations
,
per_act_token_quant
=
args
.
per_token_quantized_activations
,
per_out_ch_quant
=
args
.
per_channel_quantized_weights
,
per_out_ch_quant
=
args
.
per_channel_quantized_weights
,
...
...
tests/kernels/moe/modular_kernel_tools/common.py
View file @
49a12622
...
@@ -28,7 +28,13 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -28,7 +28,13 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
RoutingMethodType
,
)
)
from
vllm.utils.import_utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils.import_utils
import
(
has_aiter
,
has_deep_ep
,
has_deep_gemm
,
has_mori
,
has_pplx
,
)
from
.mk_objects
import
(
from
.mk_objects
import
(
TestMoEQuantConfig
,
TestMoEQuantConfig
,
...
@@ -211,6 +217,14 @@ class Config:
...
@@ -211,6 +217,14 @@ class Config:
or
info
.
backend
==
"deepep_low_latency"
or
info
.
backend
==
"deepep_low_latency"
)
)
def
needs_aiter
(
self
):
info
=
expert_info
(
self
.
fused_experts_type
)
return
info
.
needs_aiter
def
needs_mori
(
self
):
info
=
prepare_finalize_info
(
self
.
prepare_finalize_type
)
return
info
.
backend
==
"mori"
def
all2all_backend
(
self
):
def
all2all_backend
(
self
):
info
=
prepare_finalize_info
(
self
.
prepare_finalize_type
)
info
=
prepare_finalize_info
(
self
.
prepare_finalize_type
)
return
info
.
backend
return
info
.
backend
...
@@ -278,6 +292,10 @@ class Config:
...
@@ -278,6 +292,10 @@ class Config:
return
False
,
"Needs DeepGEMM, but DeepGEMM not available."
return
False
,
"Needs DeepGEMM, but DeepGEMM not available."
if
self
.
needs_pplx
()
and
not
has_pplx
():
# noqa: SIM103
if
self
.
needs_pplx
()
and
not
has_pplx
():
# noqa: SIM103
return
False
,
"Needs PPLX, but PPLX not available."
return
False
,
"Needs PPLX, but PPLX not available."
if
self
.
needs_aiter
()
and
not
has_aiter
():
# noqa: SIM103
return
False
,
"Needs Aiter, but Aiter not available."
if
self
.
needs_mori
()
and
not
has_mori
():
# noqa: SIM103
return
False
,
"Needs MoRI, but MoRI not available."
return
True
,
None
return
True
,
None
...
...
tests/kernels/moe/modular_kernel_tools/mk_objects.py
View file @
49a12622
...
@@ -37,7 +37,13 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -37,7 +37,13 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
is_deep_gemm_supported
from
vllm.utils.deep_gemm
import
is_deep_gemm_supported
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
vllm.utils.import_utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils.import_utils
import
(
has_aiter
,
has_deep_ep
,
has_deep_gemm
,
has_mori
,
has_pplx
,
)
@
dataclass
@
dataclass
...
@@ -66,6 +72,7 @@ class ExpertInfo:
...
@@ -66,6 +72,7 @@ class ExpertInfo:
supports_expert_map
:
bool
supports_expert_map
:
bool
needs_matching_quant
:
bool
=
False
needs_matching_quant
:
bool
=
False
needs_deep_gemm
:
bool
=
False
needs_deep_gemm
:
bool
=
False
needs_aiter
:
bool
=
False
PREPARE_FINALIZE_INFO
:
dict
[
mk
.
FusedMoEPrepareAndFinalize
,
PrepareFinalizeInfo
]
=
{}
PREPARE_FINALIZE_INFO
:
dict
[
mk
.
FusedMoEPrepareAndFinalize
,
PrepareFinalizeInfo
]
=
{}
...
@@ -126,6 +133,7 @@ def register_experts(
...
@@ -126,6 +133,7 @@ def register_experts(
supports_expert_map
:
bool
,
supports_expert_map
:
bool
,
needs_matching_quant
:
bool
=
False
,
needs_matching_quant
:
bool
=
False
,
needs_deep_gemm
:
bool
=
False
,
needs_deep_gemm
:
bool
=
False
,
needs_aiter
:
bool
=
False
,
):
):
global
EXPERT_INFO
global
EXPERT_INFO
global
MK_FUSED_EXPERT_TYPES
global
MK_FUSED_EXPERT_TYPES
...
@@ -139,6 +147,7 @@ def register_experts(
...
@@ -139,6 +147,7 @@ def register_experts(
supports_expert_map
,
supports_expert_map
,
needs_matching_quant
,
needs_matching_quant
,
needs_deep_gemm
,
needs_deep_gemm
,
needs_aiter
,
)
)
MK_FUSED_EXPERT_TYPES
.
append
(
kind
)
MK_FUSED_EXPERT_TYPES
.
append
(
kind
)
...
@@ -218,6 +227,20 @@ if has_deep_ep() and not current_platform.has_device_capability(100):
...
@@ -218,6 +227,20 @@ if has_deep_ep() and not current_platform.has_device_capability(100):
backend
=
"deepep_low_latency"
,
backend
=
"deepep_low_latency"
,
)
)
if
has_mori
():
from
vllm.model_executor.layers.fused_moe.mori_prepare_finalize
import
(
MoriPrepareAndFinalize
,
)
register_prepare_and_finalize
(
MoriPrepareAndFinalize
,
standard_format
,
fp8_types
,
blocked_quantization_support
=
True
,
backend
=
"mori"
,
supports_apply_weight_on_input
=
False
,
)
if
has_pplx
():
if
has_pplx
():
from
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize
import
(
from
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize
import
(
PplxPrepareAndFinalize
,
PplxPrepareAndFinalize
,
...
@@ -261,6 +284,25 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
...
@@ -261,6 +284,25 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
)
)
else
:
else
:
FlashInferCutlassMoEPrepareAndFinalize
=
None
FlashInferCutlassMoEPrepareAndFinalize
=
None
FlashInferExperts
=
None
if
has_aiter
():
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
AiterExperts
,
)
register_experts
(
AiterExperts
,
standard_format
,
fp8_types
,
blocked_quantization_support
=
True
,
supports_chunking
=
True
,
supports_expert_map
=
True
,
needs_aiter
=
True
,
)
else
:
AiterExperts
=
None
if
has_deep_gemm
()
and
is_deep_gemm_supported
():
if
has_deep_gemm
()
and
is_deep_gemm_supported
():
register_experts
(
register_experts
(
...
@@ -316,6 +358,9 @@ if cutlass_fp8_supported():
...
@@ -316,6 +358,9 @@ if cutlass_fp8_supported():
supports_chunking
=
False
,
supports_chunking
=
False
,
supports_expert_map
=
False
,
supports_expert_map
=
False
,
)
)
else
:
CutlassBatchedExpertsFp8
=
None
CutlassExpertsFp8
=
None
if
cutlass_fp4_supported
():
if
cutlass_fp4_supported
():
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassExpertsFp4
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassExpertsFp4
...
@@ -328,6 +373,8 @@ if cutlass_fp4_supported():
...
@@ -328,6 +373,8 @@ if cutlass_fp4_supported():
supports_chunking
=
True
,
supports_chunking
=
True
,
supports_expert_map
=
False
,
supports_expert_map
=
False
,
)
)
else
:
CutlassExpertsFp4
=
None
MK_QUANT_CONFIGS
:
list
[
TestMoEQuantConfig
|
None
]
=
[
MK_QUANT_CONFIGS
:
list
[
TestMoEQuantConfig
|
None
]
=
[
None
,
None
,
...
...
vllm/_aiter_ops.py
View file @
49a12622
...
@@ -79,6 +79,8 @@ def _rocm_aiter_fused_moe_impl(
...
@@ -79,6 +79,8 @@ def _rocm_aiter_fused_moe_impl(
w2_scale
:
torch
.
Tensor
|
None
=
None
,
w2_scale
:
torch
.
Tensor
|
None
=
None
,
a1_scale
:
torch
.
Tensor
|
None
=
None
,
a1_scale
:
torch
.
Tensor
|
None
=
None
,
a2_scale
:
torch
.
Tensor
|
None
=
None
,
a2_scale
:
torch
.
Tensor
|
None
=
None
,
num_local_tokens
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
aiter
import
ActivationType
,
QuantType
from
aiter
import
ActivationType
,
QuantType
from
aiter.fused_moe
import
fused_moe
from
aiter.fused_moe
import
fused_moe
...
@@ -100,6 +102,8 @@ def _rocm_aiter_fused_moe_impl(
...
@@ -100,6 +102,8 @@ def _rocm_aiter_fused_moe_impl(
w2_scale
,
w2_scale
,
a1_scale
,
a1_scale
,
a2_scale
,
a2_scale
,
num_local_tokens
=
num_local_tokens
,
dtype
=
output_dtype
,
)
)
...
@@ -117,7 +121,11 @@ def _rocm_aiter_fused_moe_fake(
...
@@ -117,7 +121,11 @@ def _rocm_aiter_fused_moe_fake(
w2_scale
:
torch
.
Tensor
|
None
=
None
,
w2_scale
:
torch
.
Tensor
|
None
=
None
,
a1_scale
:
torch
.
Tensor
|
None
=
None
,
a1_scale
:
torch
.
Tensor
|
None
=
None
,
a2_scale
:
torch
.
Tensor
|
None
=
None
,
a2_scale
:
torch
.
Tensor
|
None
=
None
,
num_local_tokens
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
output_dtype
is
not
None
:
return
torch
.
empty_like
(
hidden_states
,
dtype
=
output_dtype
)
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
...
@@ -1236,6 +1244,8 @@ class rocm_aiter_ops:
...
@@ -1236,6 +1244,8 @@ class rocm_aiter_ops:
w2_scale
:
torch
.
Tensor
|
None
=
None
,
w2_scale
:
torch
.
Tensor
|
None
=
None
,
a1_scale
:
torch
.
Tensor
|
None
=
None
,
a1_scale
:
torch
.
Tensor
|
None
=
None
,
a2_scale
:
torch
.
Tensor
|
None
=
None
,
a2_scale
:
torch
.
Tensor
|
None
=
None
,
num_local_tokens
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
rocm_aiter_fused_moe
(
return
torch
.
ops
.
vllm
.
rocm_aiter_fused_moe
(
hidden_states
,
hidden_states
,
...
@@ -1251,6 +1261,8 @@ class rocm_aiter_ops:
...
@@ -1251,6 +1261,8 @@ class rocm_aiter_ops:
w2_scale
,
w2_scale
,
a1_scale
,
a1_scale
,
a2_scale
,
a2_scale
,
num_local_tokens
,
output_dtype
,
)
)
@
staticmethod
@
staticmethod
...
...
vllm/config/parallel.py
View file @
49a12622
...
@@ -43,6 +43,7 @@ All2AllBackend = Literal[
...
@@ -43,6 +43,7 @@ All2AllBackend = Literal[
"pplx"
,
"pplx"
,
"deepep_high_throughput"
,
"deepep_high_throughput"
,
"deepep_low_latency"
,
"deepep_low_latency"
,
"mori"
,
"allgather_reducescatter"
,
"allgather_reducescatter"
,
"flashinfer_all2allv"
,
"flashinfer_all2allv"
,
]
]
...
@@ -158,6 +159,7 @@ class ParallelConfig:
...
@@ -158,6 +159,7 @@ class ParallelConfig:
- "pplx": Use pplx kernels
\n
- "pplx": Use pplx kernels
\n
- "deepep_high_throughput": Use deepep high-throughput kernels
\n
- "deepep_high_throughput": Use deepep high-throughput kernels
\n
- "deepep_low_latency": Use deepep low-latency kernels
\n
- "deepep_low_latency": Use deepep low-latency kernels
\n
- "mori": Use mori kernels
\n
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
max_parallel_loading_workers
:
int
|
None
=
None
max_parallel_loading_workers
:
int
|
None
=
None
...
@@ -443,6 +445,7 @@ class ParallelConfig:
...
@@ -443,6 +445,7 @@ class ParallelConfig:
"naive"
,
"naive"
,
"deepep_high_throughput"
,
"deepep_high_throughput"
,
"deepep_low_latency"
,
"deepep_low_latency"
,
"mori"
,
)
)
and
self
.
enable_expert_parallel
and
self
.
enable_expert_parallel
and
self
.
tensor_parallel_size
>
1
and
self
.
tensor_parallel_size
>
1
...
...
vllm/distributed/device_communicators/all2all.py
View file @
49a12622
...
@@ -10,7 +10,7 @@ from vllm.distributed import get_dp_group, get_ep_group
...
@@ -10,7 +10,7 @@ from vllm.distributed import get_dp_group, get_ep_group
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils.flashinfer
import
has_flashinfer_all2all
from
vllm.utils.flashinfer
import
has_flashinfer_all2all
from
vllm.utils.import_utils
import
has_deep_ep
,
has_pplx
from
vllm.utils.import_utils
import
has_deep_ep
,
has_mori
,
has_pplx
from
.base_device_communicator
import
All2AllManagerBase
,
Cache
from
.base_device_communicator
import
All2AllManagerBase
,
Cache
...
@@ -507,3 +507,96 @@ class FlashInferAllToAllManager(All2AllManagerBase):
...
@@ -507,3 +507,96 @@ class FlashInferAllToAllManager(All2AllManagerBase):
self
.
prepare_workspace_tensor
=
None
self
.
prepare_workspace_tensor
=
None
self
.
mapping
=
None
self
.
mapping
=
None
self
.
initialized
=
False
self
.
initialized
=
False
class
MoriAll2AllManager
(
All2AllManagerBase
):
def
__init__
(
self
,
cpu_group
):
assert
has_mori
(),
(
"MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
" to install MoRI kernels."
)
# noqa
import
mori
super
().
__init__
(
cpu_group
)
self
.
handle_cache
=
Cache
()
torch
.
_C
.
_distributed_c10d
.
_register_process_group
(
"mori"
,
cpu_group
)
mori
.
shmem
.
shmem_torch_process_group_init
(
"mori"
)
def
_make_all2all_kwargs
(
self
,
rank
:
int
,
num_ep_ranks
:
int
,
input_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
token_hidden_size
:
int
,
scale_dim
:
int
,
scale_type_size
:
int
,
max_num_tokens_per_dp_rank
:
int
,
num_local_experts
:
int
,
num_experts_per_token
:
int
,
):
import
mori
# type: ignore[import-not-found]
from
vllm.platforms.rocm
import
on_gfx942
,
on_gfx950
assert
on_gfx942
()
or
on_gfx950
(),
(
"mori currently only support arch gfx942 and gfx950"
)
if
not
self
.
internode
:
# single node
kernel_type
=
mori
.
ops
.
EpDispatchCombineKernelType
.
IntraNode
rdma_block_num
=
0
warp_num_per_block
=
16
block_num
=
80
else
:
# multi node
kernel_type
=
mori
.
ops
.
EpDispatchCombineKernelType
.
InterNodeV1
if
on_gfx942
():
warp_num_per_block
=
16
block_num
=
32
rdma_block_num
=
16
elif
on_gfx950
():
warp_num_per_block
=
8
block_num
=
64
rdma_block_num
=
32
else
:
raise
NotImplementedError
(
"mori currently only support arch gfx942 and gfx950"
)
return
dict
(
rank
=
rank
,
world_size
=
num_ep_ranks
,
data_type
=
quant_dtype
,
hidden_dim
=
token_hidden_size
,
scale_dim
=
scale_dim
,
scale_type_size
=
scale_type_size
,
max_token_type_size
=
input_dtype
.
itemsize
,
max_num_inp_token_per_rank
=
max_num_tokens_per_dp_rank
,
num_experts_per_rank
=
num_local_experts
,
num_experts_per_token
=
num_experts_per_token
,
warp_num_per_block
=
warp_num_per_block
,
block_num
=
block_num
,
kernel_type
=
kernel_type
,
rdma_block_num
=
rdma_block_num
,
gpu_per_node
=
min
(
8
,
num_ep_ranks
),
)
def
_make_handle
(
self
,
**
kwargs
):
import
mori
# type: ignore[import-not-found]
mori_config
=
mori
.
ops
.
EpDispatchCombineConfig
(
**
kwargs
)
handle
=
mori
.
ops
.
EpDispatchCombineOp
(
mori_config
)
return
handle
def
get_handle
(
self
,
kwargs
):
import
mori
# type: ignore[import-not-found]
mori_kwargs
=
self
.
_make_all2all_kwargs
(
**
kwargs
)
logger
.
debug
(
"MoRI all2all args %s"
,
mori_kwargs
)
handle
:
mori
.
ops
.
EpDispatchCombineOp
=
self
.
handle_cache
.
get_or_create
(
mori_kwargs
,
self
.
_make_handle
)
return
handle
vllm/distributed/device_communicators/cuda_communicator.py
View file @
49a12622
...
@@ -110,6 +110,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -110,6 +110,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
from
.all2all
import
DeepEPLLAll2AllManager
from
.all2all
import
DeepEPLLAll2AllManager
self
.
all2all_manager
=
DeepEPLLAll2AllManager
(
self
.
cpu_group
)
self
.
all2all_manager
=
DeepEPLLAll2AllManager
(
self
.
cpu_group
)
elif
self
.
all2all_backend
==
"mori"
:
from
.all2all
import
MoriAll2AllManager
self
.
all2all_manager
=
MoriAll2AllManager
(
self
.
cpu_group
)
elif
self
.
all2all_backend
==
"flashinfer_all2allv"
:
elif
self
.
all2all_backend
==
"flashinfer_all2allv"
:
from
.all2all
import
FlashInferAllToAllManager
from
.all2all
import
FlashInferAllToAllManager
...
...
vllm/envs.py
View file @
49a12622
...
@@ -187,6 +187,7 @@ if TYPE_CHECKING:
...
@@ -187,6 +187,7 @@ if TYPE_CHECKING:
"pplx"
,
"pplx"
,
"deepep_high_throughput"
,
"deepep_high_throughput"
,
"deepep_low_latency"
,
"deepep_low_latency"
,
"mori"
,
"allgather_reducescatter"
,
"allgather_reducescatter"
,
"flashinfer_all2allv"
,
"flashinfer_all2allv"
,
]
=
"allgather_reducescatter"
]
=
"allgather_reducescatter"
...
@@ -1298,6 +1299,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1298,6 +1299,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "pplx": use pplx kernels
# - "pplx": use pplx kernels
# - "deepep_high_throughput", use deepep high-throughput kernels
# - "deepep_high_throughput", use deepep high-throughput kernels
# - "deepep_low_latency", use deepep low-latency kernels
# - "deepep_low_latency", use deepep low-latency kernels
# - "mori", use MoRI kernels
# - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
# - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
"VLLM_ALL2ALL_BACKEND"
:
env_with_choices
(
"VLLM_ALL2ALL_BACKEND"
:
env_with_choices
(
"VLLM_ALL2ALL_BACKEND"
,
"VLLM_ALL2ALL_BACKEND"
,
...
@@ -1307,6 +1309,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1307,6 +1309,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"pplx"
,
"pplx"
,
"deepep_high_throughput"
,
"deepep_high_throughput"
,
"deepep_low_latency"
,
"deepep_low_latency"
,
"mori"
,
"allgather_reducescatter"
,
"allgather_reducescatter"
,
"flashinfer_all2allv"
,
"flashinfer_all2allv"
,
],
],
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
49a12622
...
@@ -88,6 +88,9 @@ if HAS_TRITON:
...
@@ -88,6 +88,9 @@ if HAS_TRITON:
fused_experts
,
fused_experts
,
get_config_file_name
,
get_config_file_name
,
)
)
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
AiterExperts
,
)
from
vllm.model_executor.layers.fused_moe.router.fused_topk_router
import
(
from
vllm.model_executor.layers.fused_moe.router.fused_topk_router
import
(
fused_topk
,
fused_topk
,
)
)
...
@@ -99,6 +102,7 @@ if HAS_TRITON:
...
@@ -99,6 +102,7 @@ if HAS_TRITON:
)
)
__all__
+=
[
__all__
+=
[
"AiterExperts"
,
"fused_topk"
,
"fused_topk"
,
"fused_experts"
,
"fused_experts"
,
"get_config_file_name"
,
"get_config_file_name"
,
...
...
vllm/model_executor/layers/fused_moe/all2all_utils.py
View file @
49a12622
...
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
...
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize
,
FusedMoEPrepareAndFinalize
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.import_utils
import
has_deep_ep
,
has_pplx
from
vllm.utils.import_utils
import
has_deep_ep
,
has_mori
,
has_pplx
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
if
has_pplx
():
if
has_pplx
():
...
@@ -30,6 +30,8 @@ if current_platform.is_cuda_alike():
...
@@ -30,6 +30,8 @@ if current_platform.is_cuda_alike():
DEEPEP_QUANT_BLOCK_SHAPE
,
DEEPEP_QUANT_BLOCK_SHAPE
,
DeepEPLLPrepareAndFinalize
,
DeepEPLLPrepareAndFinalize
,
)
)
if
has_mori
():
from
.mori_prepare_finalize
import
MoriPrepareAndFinalize
def
maybe_roundup_layer_hidden_size
(
def
maybe_roundup_layer_hidden_size
(
...
@@ -169,5 +171,36 @@ def maybe_make_prepare_finalize(
...
@@ -169,5 +171,36 @@ def maybe_make_prepare_finalize(
physical_to_global
=
physical_to_global
,
physical_to_global
=
physical_to_global
,
local_expert_global_ids
=
local_expert_global_ids
,
local_expert_global_ids
=
local_expert_global_ids
,
)
)
elif
moe
.
use_mori_kernels
:
assert
quant_config
is
not
None
# Note: We may want to use FP8 dispatch just to reduce
# data movement.
use_fp8_dispatch
=
(
quant_config
.
is_per_act_token
or
quant_config
.
is_block_quantized
)
# For PTPC (per token per channel) quant, the scale dim for each token is 1
# For 1x128 quant, the scale dim for each token is hidden_dim // 128
scale_dim
=
1
if
quant_config
.
is_per_act_token
else
moe
.
hidden_dim
//
128
all_to_all_args
=
dict
(
rank
=
all2all_manager
.
rank
,
num_ep_ranks
=
all2all_manager
.
world_size
,
quant_dtype
=
quant_config
.
quant_dtype
,
token_hidden_size
=
moe
.
hidden_dim
,
scale_dim
=
scale_dim
,
scale_type_size
=
torch
.
float32
.
itemsize
,
max_num_tokens_per_dp_rank
=
moe
.
max_num_tokens
,
input_dtype
=
moe
.
in_dtype
,
num_local_experts
=
moe
.
num_experts
//
all2all_manager
.
world_size
,
num_experts_per_token
=
moe
.
experts_per_token
,
)
handle
=
all2all_manager
.
get_handle
(
all_to_all_args
)
prepare_finalize
=
MoriPrepareAndFinalize
(
handle
,
max_tokens_per_rank
=
moe
.
max_num_tokens
,
num_dispatchers
=
all2all_manager
.
world_size
,
use_fp8_dispatch
=
use_fp8_dispatch
,
)
return
prepare_finalize
return
prepare_finalize
vllm/model_executor/layers/fused_moe/config.py
View file @
49a12622
...
@@ -893,6 +893,10 @@ class FusedMoEParallelConfig:
...
@@ -893,6 +893,10 @@ class FusedMoEParallelConfig:
self
.
all2all_backend
in
[
"naive"
,
"allgather_reducescatter"
]
self
.
all2all_backend
in
[
"naive"
,
"allgather_reducescatter"
]
)
)
@
property
def
use_mori_kernels
(
self
):
return
self
.
use_all2all_kernels
and
self
.
all2all_backend
==
"mori"
@
staticmethod
@
staticmethod
def
flatten_tp_across_dp_and_pcp
(
def
flatten_tp_across_dp_and_pcp
(
tp_size
:
int
,
dp_size
:
int
,
dp_rank
:
int
,
pcp_size
:
int
,
pcp_rank
:
int
tp_size
:
int
,
dp_size
:
int
,
dp_rank
:
int
,
pcp_size
:
int
,
pcp_rank
:
int
...
@@ -1136,6 +1140,10 @@ class FusedMoEConfig:
...
@@ -1136,6 +1140,10 @@ class FusedMoEConfig:
def
use_deepep_ll_kernels
(
self
):
def
use_deepep_ll_kernels
(
self
):
return
self
.
moe_parallel_config
.
use_deepep_ll_kernels
return
self
.
moe_parallel_config
.
use_deepep_ll_kernels
@
property
def
use_mori_kernels
(
self
):
return
self
.
moe_parallel_config
.
use_mori_kernels
@
property
@
property
def
use_flashinfer_cutlass_kernels
(
self
):
def
use_flashinfer_cutlass_kernels
(
self
):
"""
"""
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
49a12622
...
@@ -570,6 +570,14 @@ class FusedMoE(CustomOp):
...
@@ -570,6 +570,14 @@ class FusedMoE(CustomOp):
self
.
moe_config_use_flashinfer_cutlass_kernels
=
(
self
.
moe_config_use_flashinfer_cutlass_kernels
=
(
self
.
moe_config
.
use_flashinfer_cutlass_kernels
self
.
moe_config
.
use_flashinfer_cutlass_kernels
)
)
if
self
.
use_mori_kernels
:
assert
self
.
rocm_aiter_fmoe_enabled
,
(
"Mori needs to be used with aiter fused_moe for now."
)
assert
not
self
.
aiter_fmoe_shared_expert_enabled
,
(
"Mori does not support fusion shared expert now. "
"Turn it off by setting VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0"
)
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
@@ -712,6 +720,10 @@ class FusedMoE(CustomOp):
...
@@ -712,6 +720,10 @@ class FusedMoE(CustomOp):
def
use_deepep_ll_kernels
(
self
):
def
use_deepep_ll_kernels
(
self
):
return
self
.
moe_parallel_config
.
use_deepep_ll_kernels
return
self
.
moe_parallel_config
.
use_deepep_ll_kernels
@
property
def
use_mori_kernels
(
self
):
return
self
.
moe_parallel_config
.
use_mori_kernels
@
property
@
property
def
use_flashinfer_cutlass_kernels
(
self
):
def
use_flashinfer_cutlass_kernels
(
self
):
return
(
return
(
...
@@ -729,6 +741,7 @@ class FusedMoE(CustomOp):
...
@@ -729,6 +741,7 @@ class FusedMoE(CustomOp):
return
(
return
(
self
.
moe_parallel_config
.
use_pplx_kernels
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
or
self
.
moe_parallel_config
.
use_mori_kernels
or
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
or
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
)
and
envs
.
VLLM_ENABLE_MOE_DP_CHUNK
)
and
envs
.
VLLM_ENABLE_MOE_DP_CHUNK
...
...
vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
0 → 100644
View file @
49a12622
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
mori
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
class
MoriPrepareAndFinalize
(
mk
.
FusedMoEPrepareAndFinalize
):
"""
Prepare/Finalize using MoRI kernels.
"""
def
__init__
(
self
,
mori_op
:
mori
.
ops
.
EpDispatchCombineOp
,
max_tokens_per_rank
:
int
,
num_dispatchers
:
int
,
use_fp8_dispatch
:
bool
=
False
,
):
super
().
__init__
()
self
.
mori_op
=
mori_op
self
.
num_dispatchers_
=
num_dispatchers
self
.
max_tokens_per_rank
=
max_tokens_per_rank
self
.
use_fp8_dispatch
=
use_fp8_dispatch
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
def
output_is_reduced
(
self
)
->
bool
:
return
True
def
num_dispatchers
(
self
):
return
self
.
num_dispatchers_
def
max_num_tokens_per_rank
(
self
)
->
int
|
None
:
return
self
.
max_tokens_per_rank
def
topk_indices_dtype
(
self
)
->
torch
.
dtype
|
None
:
return
torch
.
int32
def
supports_async
(
self
)
->
bool
:
return
False
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
"""
Returns a tuple of:
- quantized + dispatched a.
- Optional quantized + dispatched a1_scales.
- Optional ExpertTokensMetadata containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert.
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weight
"""
assert
not
apply_router_weight_on_input
,
(
"mori does not support apply_router_weight_on_input=True now."
)
scale
=
None
if
self
.
use_fp8_dispatch
:
from
aiter
import
QuantType
,
get_hip_quant
if
quant_config
.
is_block_quantized
:
quant_func
=
get_hip_quant
(
QuantType
.
per_1x128
)
a1
,
scale
=
quant_func
(
a1
,
quant_dtype
=
current_platform
.
fp8_dtype
())
elif
quant_config
.
is_per_act_token
:
quant_func
=
get_hip_quant
(
QuantType
.
per_Token
)
a1
,
scale
=
quant_func
(
a1
,
quant_dtype
=
current_platform
.
fp8_dtype
())
(
dispatch_a1
,
dispatch_weights
,
dispatch_scale
,
dispatch_ids
,
dispatch_recv_token_num
,
)
=
self
.
mori_op
.
dispatch
(
a1
,
topk_weights
,
scale
,
topk_ids
)
expert_tokens_meta
=
mk
.
ExpertTokensMetadata
(
expert_num_tokens
=
dispatch_recv_token_num
,
expert_num_tokens_cpu
=
None
)
return
(
dispatch_a1
,
dispatch_scale
,
expert_tokens_meta
,
dispatch_ids
,
dispatch_weights
,
)
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
None
:
num_token
=
output
.
shape
[
0
]
result
=
self
.
mori_op
.
combine
(
fused_expert_output
,
None
,
topk_ids
,
)[
0
]
output
.
copy_
(
result
[:
num_token
])
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
49a12622
...
@@ -188,6 +188,9 @@ def rocm_aiter_fused_experts(
...
@@ -188,6 +188,9 @@ def rocm_aiter_fused_experts(
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
quant_config
:
FusedMoEQuantConfig
|
None
=
None
,
quant_config
:
FusedMoEQuantConfig
|
None
=
None
,
a1q_scale
:
torch
.
Tensor
|
None
=
None
,
num_local_tokens
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
quant_config
is
None
:
if
quant_config
is
None
:
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
...
@@ -216,6 +219,9 @@ def rocm_aiter_fused_experts(
...
@@ -216,6 +219,9 @@ def rocm_aiter_fused_experts(
assert
topk_weights
.
shape
[
-
1
]
==
1
,
(
assert
topk_weights
.
shape
[
-
1
]
==
1
,
(
"Only support topk=1 when `apply_router_weight_on_input` is True"
"Only support topk=1 when `apply_router_weight_on_input` is True"
)
)
assert
num_local_tokens
is
None
,
(
"AITER tkw1 kernel does not support `num_local_tokens`"
)
return
rocm_aiter_ops
.
asm_moe_tkw1
(
return
rocm_aiter_ops
.
asm_moe_tkw1
(
hidden_states
,
hidden_states
,
...
@@ -272,9 +278,11 @@ def rocm_aiter_fused_experts(
...
@@ -272,9 +278,11 @@ def rocm_aiter_fused_experts(
activation_method
=
activation_method
,
activation_method
=
activation_method
,
w1_scale
=
quant_config
.
w1_scale
,
w1_scale
=
quant_config
.
w1_scale
,
w2_scale
=
quant_config
.
w2_scale
,
w2_scale
=
quant_config
.
w2_scale
,
a1_scale
=
quant_config
.
a1_scale
,
a1_scale
=
quant_config
.
a1_scale
if
a1q_scale
is
None
else
a1q_scale
,
a2_scale
=
quant_config
.
a2_scale
,
a2_scale
=
quant_config
.
a2_scale
,
doweight_stage1
=
apply_router_weight_on_input
,
doweight_stage1
=
apply_router_weight_on_input
,
num_local_tokens
=
num_local_tokens
,
output_dtype
=
output_dtype
,
)
)
...
@@ -370,9 +378,12 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -370,9 +378,12 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
# TODO(rob): rocm_aiter_fused_experts uses self.quant_config's
# TODO(rob): rocm_aiter_fused_experts uses self.quant_config's
# a_scales for static quantization. Update this to fit better
# a_scales for static quantization. Update this to fit better
# with the interface once all quant integrations are complete.
# with the interface once all quant integrations are complete.
assert
a1q_scale
is
None
assert
a2_scale
==
self
.
quant_config
.
a2_scale
assert
a2_scale
==
self
.
quant_config
.
a2_scale
assert
expert_tokens_meta
is
None
if
expert_tokens_meta
is
not
None
:
num_local_tokens
=
expert_tokens_meta
.
expert_num_tokens
else
:
num_local_tokens
=
None
result
=
rocm_aiter_fused_experts
(
result
=
rocm_aiter_fused_experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
@@ -384,6 +395,8 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -384,6 +395,8 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
apply_router_weight_on_input
=
apply_router_weight_on_input
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
quant_config
=
self
.
quant_config
,
quant_config
=
self
.
quant_config
,
a1q_scale
=
a1q_scale
,
num_local_tokens
=
num_local_tokens
,
output_dtype
=
output
.
dtype
,
)
)
assert
result
.
shape
==
output
.
shape
output
.
copy_
(
result
)
output
.
copy_
(
result
)
vllm/platforms/rocm.py
View file @
49a12622
...
@@ -106,6 +106,12 @@ def on_gfx9() -> bool:
...
@@ -106,6 +106,12 @@ def on_gfx9() -> bool:
return
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
,
"gfx950"
])
return
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
,
"gfx950"
])
@
cache
def
on_gfx942
()
->
bool
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx942"
])
@
cache
@
cache
def
on_gfx950
()
->
bool
:
def
on_gfx950
()
->
bool
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
...
...
vllm/utils/import_utils.py
View file @
49a12622
...
@@ -451,3 +451,13 @@ def has_helion() -> bool:
...
@@ -451,3 +451,13 @@ def has_helion() -> bool:
# use helion...
# use helion...
"""
"""
return
_has_module
(
"helion"
)
return
_has_module
(
"helion"
)
def
has_aiter
()
->
bool
:
"""Whether the optional `aiter` package is available."""
return
_has_module
(
"aiter"
)
def
has_mori
()
->
bool
:
"""Whether the optional `mori` package is available."""
return
_has_module
(
"mori"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment