"...api/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "946bb53c566ef75a6c9417ca399b7e072d243c04"
Unverified commit 755f3147, authored by Alex Sun, committed by GitHub

[AMD] add aiter fused moe in DeepEP path (#7268)

parent 7c3a12c0
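For context: the aiter fused-MoE path added in this commit is opt-in. It is taken only when the SGLANG_USE_AITER environment variable is set and the build is a ROCm (HIP) build, matching the `_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()` gate in the hunks below. A minimal standalone sketch of that gate; the helper bodies here are assumptions for illustration, not sglang's actual implementations:

```python
import os

import torch


def _get_bool_env_var(name: str, default: str = "false") -> bool:
    # Stand-in for sglang.srt.utils.get_bool_env_var; the accepted truthy
    # values are an assumption made for this sketch.
    return os.getenv(name, default).lower() in ("1", "true", "yes")


def _is_hip() -> bool:
    # Stand-in for sglang.srt.utils.is_hip: True when PyTorch was built
    # against ROCm/HIP.
    return torch.version.hip is not None


# Mirrors the gate added in this commit: the aiter fused-MoE path is taken
# only when SGLANG_USE_AITER is set and the build is a ROCm (HIP) build.
_use_aiter = _get_bool_env_var("SGLANG_USE_AITER") and _is_hip()
print("aiter path enabled:", _use_aiter)
```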
@@ -54,10 +54,16 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
+if _use_aiter:
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.shuffle import shuffle_weight
 logger = logging.getLogger(__name__)
@@ -1046,6 +1052,15 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 w2_weight_scale, requires_grad=False
             )
             layer.w2_input_scale = None
+        if _use_aiter:
+            layer.w13_weight = torch.nn.Parameter(
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
+                requires_grad=False,
+            )
+            layer.w2_weight = torch.nn.Parameter(
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
+                requires_grad=False,
+            )
         return

     def apply(
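The shuffle_weight post-processing above follows a common pattern: after the checkpoint is loaded, each fp8 weight is replaced by a frozen torch.nn.Parameter holding a layout-transformed copy (here, aiter's (16, 16) tile shuffle). A toy sketch of that re-wrapping pattern, with a placeholder transform standing in for aiter.ops.shuffle.shuffle_weight, whose internals are not shown in this diff:

```python
import torch


def rewrap_param(module: torch.nn.Module, name: str, transform) -> None:
    # Replace a registered parameter with a transformed, non-trainable copy,
    # mirroring how this commit swaps w13_weight/w2_weight for their shuffled
    # layouts after loading.
    old = getattr(module, name)
    setattr(
        module,
        name,
        torch.nn.Parameter(transform(old.data), requires_grad=False),
    )


# Usage with a placeholder transform (a contiguous copy instead of the real
# (16, 16) tile shuffle performed by aiter's shuffle_weight):
layer = torch.nn.Linear(8, 8, bias=False)
rewrap_param(layer, "weight", lambda w: w.contiguous())
assert layer.weight.requires_grad is False
```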
@@ -1117,18 +1132,36 @@ class DeepEPMoE(EPMoE):
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
+        if _use_aiter:
+            # expert_mask has size (self.num_experts_per_partition + 1); the extra
+            # slot is for the invalid rank_id (in the original DeepEP path the
+            # invalid rank_id is -1, but aiter does not allow -1, so we use a mask
+            # to mark those ids invalid).
+            # For instance, with 4 experts on this rank the mask would be:
+            #     self.expert_mask = [1, 1, 1, 1, 0]
+            # ids 0-3 are valid and will be processed, while id == 4 is masked out.
+            self.expert_mask = torch.zeros(
+                (self.num_experts_per_partition + 1),
+                device=torch.cuda.current_device(),
+                dtype=torch.int,
+            )
+            # the last slot is the invalid rank_id
+            self.expert_mask[:-1] = 1
+        else:
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                (
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                (
+                    self.w2_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w2_weight_scale
+                ),
+            )

     def forward(
         self,
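To make the expert_mask comment above concrete, here is a small standalone example on CPU; the partition size of 4 is made up, and the real code allocates the mask on the current CUDA device:

```python
import torch

num_experts_per_partition = 4  # illustrative size, not from a real config

# One extra slot at the end absorbs routing entries that DeepEP marks as
# invalid (-1), since aiter's fused_moe does not accept -1 expert ids.
expert_mask = torch.zeros(num_experts_per_partition + 1, dtype=torch.int)
expert_mask[:-1] = 1

print(expert_mask.tolist())  # [1, 1, 1, 1, 0]: ids 0-3 valid, id 4 masked out
```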
@@ -1142,6 +1175,9 @@ class DeepEPMoE(EPMoE):
         num_recv_tokens_per_expert: List[int],
         forward_mode: ForwardMode,
     ):
+        if _use_aiter:
+            # forward_aiter skips token permutation and unpermutation; both are fused inside the aiter kernel
+            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
@@ -1274,6 +1310,37 @@ class DeepEPMoE(EPMoE):
             )
         return down_output

+    def forward_aiter(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states
+        # In the original DeepEP path, idx == -1 means invalid and is not processed.
+        # aiter does not accept -1, so we use an expert mask to mark these ids invalid
+        # (idx == num_experts_per_partition means "not used" in aiter fused_moe).
+        topk_idx_copy = topk_idx.to(torch.int32)
+        topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
+        return fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            topk_weights,
+            topk_idx_copy,
+            w1_scale=self.w13_weight_scale_inv,
+            w2_scale=self.w2_weight_scale_inv,
+            quant_type=QuantType.per_128x128,
+            activation=(
+                ActivationType.Silu
+                if self.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            expert_mask=self.expert_mask,
+        )
+
     def forward_deepgemm_contiguous(
         self,
         hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
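The forward_aiter hunk above rewrites DeepEP's routing indices before calling aiter's fused_moe. A toy illustration of that remapping; the shapes and values are invented for the example:

```python
import torch

num_experts_per_partition = 4  # illustrative, matches the mask example above

# DeepEP marks unused top-k slots with -1.
topk_idx = torch.tensor([[0, 2, -1], [3, -1, -1]])

# aiter's fused_moe rejects -1, so invalid slots are redirected to the extra
# masked expert id (== num_experts_per_partition), which expert_mask zeroes out.
topk_idx_aiter = topk_idx.to(torch.int32)
topk_idx_aiter[topk_idx_aiter == -1] = num_experts_per_partition

print(topk_idx_aiter.tolist())  # [[0, 2, 4], [3, 4, 4]]
```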
......
@@ -6,7 +6,13 @@ from sglang.srt.managers.expert_distribution import (
     get_global_expert_distribution_recorder,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import DeepEPMode, get_int_env_var, load_json_config
+from sglang.srt.utils import (
+    DeepEPMode,
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    load_json_config,
+)

 try:
     from deep_ep import Buffer, Config
@@ -32,6 +38,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode

+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+
 logger = logging.getLogger(__name__)
@@ -376,6 +384,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
         https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
         """
+        if _use_aiter:
+            # Skip the permutation here; it is fused inside aiter fused_moe.
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
         reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
             topk_idx, self.num_experts
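When the aiter path is active, the dispatcher above returns placeholder tensors instead of running its own permutation, since aiter's fused_moe permutes and unpermutes tokens internally. A minimal sketch of those placeholders; num_experts and the device are illustrative, and the real code uses hidden_states.device:

```python
import torch

num_experts = 8  # illustrative
device = "cpu"   # the real code uses hidden_states.device

# No tokens are reordered on this path, so reorder_topk_ids is empty and the
# per-expert segment index pointer stays all zeros.
reorder_topk_ids = torch.empty((0,), device=device, dtype=torch.int64)
seg_indptr = torch.zeros((num_experts + 1,), device=device, dtype=torch.int64)

print(reorder_topk_ids.shape, seg_indptr.tolist())
```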
@@ -409,7 +426,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
......