Unverified Commit 726efe17 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Move the shared/fused expert output sum into MoERunnerBase (#35949)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 59556265
......@@ -236,7 +236,6 @@ class MoETestConfig:
use_gate: bool
use_routed_input_transform: bool
enable_eplb: bool = False
reduce_results: bool = False
backend: str | None = None
ep_size: int = 1
dp_size: int = 1
......@@ -295,7 +294,6 @@ def generate_valid_test_configs(
use_shared_experts,
use_gate,
use_routed_input_transform,
reduce_results,
) in product(
SHAPE_COMBOS,
NUM_EXPERTS,
......@@ -304,7 +302,6 @@ def generate_valid_test_configs(
[False, True], # shared
[False, True], # gate
[False, True], # routed input exform
[False, True], # reduce results
):
config = MoETestConfig(
shape[0], # m
......@@ -318,7 +315,6 @@ def generate_valid_test_configs(
use_gate,
use_routed_input_transform,
enable_eplb,
reduce_results,
backend,
ep_size,
dp_size,
......@@ -395,18 +391,7 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]:
and config.backend.startswith("flashinfer_nvlink")
and not current_platform.has_device_capability(90)
):
return False, "flashinfer_nvlink needs an H100+ GPUs"
# reduce_results incompatibilities
if config.reduce_results and config.use_shared_experts:
return False, "reduce_results=True is not compatible with shared_experts=True"
if config.reduce_results and config.quantization is not None:
return (
False,
"reduce_results=True only tested with unquantized data types in "
"order to limit number of tests run",
)
return False, "flashinfer_nvlink needs H100+ GPUs"
# Backend-specific checks
if config.backend is not None:
......@@ -448,10 +433,6 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]:
if config.enable_eplb and config.backend not in EPLB_SUPPORTED_BACKENDS:
return False, f"EPLB not supported with {config.backend}."
world_size = config.tp_size * config.dp_size
if config.reduce_results and world_size == 1:
return False, "reduce_results=True only makes sense for multi-GPU tests"
if (
config.backend is not None
and config.backend.startswith("flashinfer_nvlink")
......@@ -846,7 +827,6 @@ def make_fused_moe_layer(
tp_size: int,
ep_size: int,
dp_size: int,
reduce_results: bool,
w1: torch.Tensor,
w2: torch.Tensor,
top_k: int,
......@@ -874,7 +854,7 @@ def make_fused_moe_layer(
routed_input_transform: torch.nn.Module | None = None,
routed_output_transform: torch.nn.Module | None = None,
pcp_size: int | None = 1,
) -> tuple[Callable, FusedMoE]:
) -> FusedMoE:
quant_config, qw = make_quant_config(quantization, w1, w2, global_num_experts)
kwargs = dict()
......@@ -887,8 +867,10 @@ def make_fused_moe_layer(
# Add gate and routed_input_transform if provided
if gate is not None:
kwargs["gate"] = gate
if routed_input_transform is not None:
kwargs["routed_input_transform"] = routed_input_transform
kwargs["routed_output_transform"] = routed_output_transform
layer = builder(
num_experts=global_num_experts,
......@@ -896,7 +878,6 @@ def make_fused_moe_layer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=in_dtype,
reduce_results=reduce_results,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
......@@ -936,36 +917,7 @@ def make_fused_moe_layer(
layer.quant_method.process_weights_after_loading(layer)
def _moe(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
if shared_experts is None:
final_shared_states = None
final_hidden_states = layer(hidden_states, router_logits)
else:
final_shared_states, final_hidden_states = layer(
hidden_states, router_logits
)
# Apply routed output transform if provided
# (e.g., latent space -> original space)
if routed_output_transform is not None:
final_hidden_states = routed_output_transform(final_hidden_states)
if shared_experts is not None:
assert not reduce_results
assert final_shared_states is not None
final_hidden_states += final_shared_states
if not reduce_results and layer.tp_size > 1:
final_hidden_states = layer.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states
return _moe, layer
return layer
def make_fake_moe_layer(
......@@ -999,7 +951,6 @@ def make_fake_moe_layer(
tp_size: int = 1,
dp_size: int = 1,
ep_size: int = 1,
reduce_results: bool = False,
) -> Callable:
activation = MoEActivation.from_str(activation)
......@@ -1101,7 +1052,7 @@ def make_fake_moe_layer(
def _test_body_regular(
moe_fn: Callable,
moe_layer: Callable,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
vllm_config: VllmConfig,
......@@ -1118,13 +1069,12 @@ def _test_body_regular(
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
):
output = moe_fn(hidden_states, router_logits)
output = moe_layer(hidden_states, router_logits)
return baseline_output, output
def _test_body_eplb(
moe_fn: Callable,
moe_layer: FusedMoE,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -1145,7 +1095,6 @@ def _test_body_eplb(
n: int,
top_k: int,
shared_experts,
reduce_results: bool,
gate: torch.nn.Module | None,
routed_input_transform: torch.nn.Module | None,
routed_output_transform: torch.nn.Module | None,
......@@ -1161,7 +1110,7 @@ def _test_body_eplb(
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
):
output_before = moe_fn(hidden_states, router_logits)
output_before = moe_layer(hidden_states, router_logits)
# Create a fresh FusedMoE layer with enable_eplb=True
# Delete the original layer's registration so the constructor can
......@@ -1174,7 +1123,7 @@ def _test_body_eplb(
# When using routed_input_transform, experts operate in latent space
hidden_size_for_layer = k // 2 if routed_input_transform is not None else k
moe_fn, moe_layer = make_fused_moe_layer(
eplb_moe_layer = make_fused_moe_layer(
quantization=quantization,
use_ep=use_ep,
hidden_size=hidden_size_for_layer,
......@@ -1183,7 +1132,6 @@ def _test_body_eplb(
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
reduce_results=reduce_results,
w1=w1,
w2=w2,
top_k=top_k,
......@@ -1196,14 +1144,14 @@ def _test_body_eplb(
)
# Necessary?
if moe_layer._expert_map is not None:
moe_layer._expert_map = moe_layer._expert_map.to(device)
if eplb_moe_layer._expert_map is not None:
eplb_moe_layer._expert_map = eplb_moe_layer._expert_map.to(device)
# All ranks must generate the same permutation
initial_indices = torch.arange(num_experts, dtype=torch.long)
shuffled_indices = initial_indices[torch.randperm(num_experts)]
expert_weights = [list(moe_layer.get_expert_weights())]
expert_weights = [list(eplb_moe_layer.get_expert_weights())]
communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
......@@ -1227,7 +1175,7 @@ def _test_body_eplb(
num_experts, dtype=torch.int32, device=device
)
moe_layer.set_eplb_state(
eplb_moe_layer.set_eplb_state(
moe_layer_idx=0,
expert_load_view=torch.zeros(
(1, num_experts),
......@@ -1244,7 +1192,7 @@ def _test_body_eplb(
),
)
moe_layer.eplb_state.should_record_tensor = torch.ones(
eplb_moe_layer.eplb_state.should_record_tensor = torch.ones(
(), dtype=torch.bool, device=device
)
......@@ -1255,7 +1203,7 @@ def _test_body_eplb(
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
):
output_after = moe_fn(hidden_states, router_logits)
output_after = eplb_moe_layer(hidden_states, router_logits)
return output_before, output_after
......@@ -1274,7 +1222,6 @@ def _run_one_config(
num_experts: int,
top_k: int,
quantization: str | None,
reduce_results: bool,
backend: str | None,
test_body_fn: Callable,
use_shared_experts: bool,
......@@ -1341,7 +1288,6 @@ def _run_one_config(
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
reduce_results=reduce_results,
)
baseline_output = baseline_layer(hidden_states, router_logits)
......@@ -1369,7 +1315,7 @@ def _run_one_config(
hidden_size_for_layer = k // 2 if routed_input_transform is not None else k
# Create initial MoE layer
moe_fn, moe_layer = make_fused_moe_layer(
moe_layer = make_fused_moe_layer(
quantization=quantization,
use_ep=use_ep,
hidden_size=hidden_size_for_layer,
......@@ -1378,7 +1324,6 @@ def _run_one_config(
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
reduce_results=reduce_results,
w1=w1,
w2=w2,
top_k=top_k,
......@@ -1402,7 +1347,6 @@ def _run_one_config(
# Call the test body function with all necessary context
expected, actual = test_body_fn(
moe_fn=moe_fn,
moe_layer=moe_layer,
hidden_states=hidden_states,
router_logits=router_logits,
......@@ -1423,7 +1367,6 @@ def _run_one_config(
m=m,
top_k=top_k,
shared_experts=shared_experts,
reduce_results=reduce_results,
gate=gate,
routed_input_transform=routed_input_transform,
routed_output_transform=routed_output_transform,
......@@ -1520,7 +1463,6 @@ def test_moe_layer_no_parallel(
test_config.num_experts,
test_config.top_k,
test_config.quantization,
test_config.reduce_results,
test_config.backend,
_test_body_regular,
use_shared_experts=test_config.use_shared_experts,
......@@ -1578,7 +1520,6 @@ def _parallel_worker(
test_config.num_experts,
test_config.top_k,
test_config.quantization,
test_config.reduce_results,
test_config.backend,
functools.partial(
_test_body_config, test_config=test_config, cpu_group=cpu_group
......@@ -1597,7 +1538,7 @@ def _parallel_worker(
failed = failed + 1
if verbosity > 0:
traceback.print_exc()
print(f"\n{str(ex)}\nFAILED {ex.__class__}")
print(f"\n{str(ex)}\nFAILED")
else:
print("F", end="")
finally:
......
......@@ -165,7 +165,6 @@ def test_routed_input_transform_inside_vs_outside(
top_k=top_k,
hidden_size=latent_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=True,
params_dtype=dtype,
tp_size=1,
......@@ -183,7 +182,6 @@ def test_routed_input_transform_inside_vs_outside(
top_k=top_k,
hidden_size=latent_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=True,
params_dtype=dtype,
tp_size=1,
......@@ -212,34 +210,20 @@ def test_routed_input_transform_inside_vs_outside(
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
router_logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
# Clone inputs so any in-place modification by Method A
# cannot affect Method B's computation.
hidden_states_A = hidden_states.clone()
router_logits_A = router_logits.clone()
with set_forward_context(None, vllm_config, num_tokens=num_tokens):
shared_out_A, routed_out_A = moe_with_transform(
hidden_states_A, router_logits_A
)
# Method A: combined output (shared + routed)
combined_A = moe_with_transform(hidden_states, router_logits)
# Method B: manually transform, get routed output, add shared
transformed_hidden = routed_transform(hidden_states)
shared_out_B, routed_out_B = moe_without_transform(
transformed_hidden, router_logits
)
expected_shared_out = shared_experts(hidden_states)
routed_out_B = moe_without_transform(transformed_hidden, router_logits)
shared_out_B = shared_experts(hidden_states)
combined_B = shared_out_B + routed_out_B
_assert_close(
routed_out_A,
routed_out_B,
atol=1e-3,
rtol=1e-3,
label="Routed output: transform inside vs outside",
)
_assert_close(
shared_out_A,
expected_shared_out,
torch.testing.assert_close(
combined_A,
combined_B,
atol=1e-3,
rtol=1e-3,
label="Shared expert output",
msg="Combined output should match: transform inside vs outside",
)
......@@ -592,9 +592,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs)
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
@property
def quant_method(self):
return self.base_layer.quant_method
......
......@@ -716,7 +716,7 @@ class MarlinExperts(MarlinExpertsBase):
):
assert self.w1_scale is not None
assert self.w2_scale is not None
return fused_marlin_moe(
fused_marlin_moe(
hidden_states=hidden_states,
w1=w1,
w2=w2,
......
......@@ -230,11 +230,18 @@ class FusedMoE(PluggableLayer):
hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters.
reduce_results: Whether to all_reduce on the output of the layer
renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure.
enable_eplb: Whether to enable expert parallelism load balancer.
router_logits_dtype: Data type for router logits buffers.
routed_scaling_factor: A scaling factor that is applied to the topk_weights
by the router or the output of the layer depending
on the value of `apply_routed_scale_to_output`
apply_routed_scale_to_output: Determine whether or not `routed_scaling_factor`
is applied to the topk_weights or to the experts
output. It is applied to the experts output
instead of the topk_weights when this feature is
not supported by the router (or the experts).
"""
# --8<-- [end:fused_moe]
......@@ -246,7 +253,6 @@ class FusedMoE(PluggableLayer):
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype | None = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: int | None = None,
......@@ -274,12 +280,12 @@ class FusedMoE(PluggableLayer):
gate: torch.nn.Module | None = None,
shared_experts: torch.nn.Module | None = None,
routed_input_transform: torch.nn.Module | None = None,
routed_output_transform: torch.nn.Module | None = None,
apply_routed_scale_to_output: bool = False,
zero_expert_type: str | None = None,
):
super().__init__()
self._routed_input_transform = routed_input_transform
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
......@@ -425,7 +431,6 @@ class FusedMoE(PluggableLayer):
assert intermediate_size % self.tp_size == 0
intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
# TODO(bnell): these attributes are only used by monolithic kernels.
......@@ -437,7 +442,14 @@ class FusedMoE(PluggableLayer):
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
# When apply_routed_scale_to_output is True, we set the scaling factor
# to 1.0 so it ends up being a nop. Applying the scale will be handled
# by the runner in this case.
# The member variable must be set in the same way as the router since
# some quantization methods can access it.
self.routed_scaling_factor = (
routed_scaling_factor if not apply_routed_scale_to_output else 1.0
)
self.e_score_correction_bias = e_score_correction_bias
# TODO(bnell): end attributes
......@@ -456,7 +468,7 @@ class FusedMoE(PluggableLayer):
topk_group=topk_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=self.num_fused_shared_experts,
enable_eplb=enable_eplb,
......@@ -578,12 +590,18 @@ class FusedMoE(PluggableLayer):
layer_name=self.layer_name,
moe_config=self.moe_config,
router=self.router,
routed_input_transform=self._routed_input_transform,
gate=gate,
shared_experts=shared_experts,
quant_method=self.quant_method,
reduce_results=self.reduce_results,
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
routed_input_transform=routed_input_transform,
routed_output_transform=routed_output_transform,
# When apply_routed_scale_to_output is True, we allow
# the scaling factor to be passed to the runner, otherwise
# we pass 1.0 so it ends up being a nop.
routed_scaling_factor=routed_scaling_factor
if apply_routed_scale_to_output
else 1.0,
)
# TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py
......@@ -1514,32 +1532,11 @@ class FusedMoE(PluggableLayer):
self.ensure_moe_quant_config_init()
return self.quant_method.moe_quant_config
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
layer. The result of this function is typically used as
the reduce_results argument to the module.
When just tensor-parallel is used, it is not required to reduce
the shared_experts results immediately. Instead we reduce at the
once at the end of the MoE op. (Refer to DeepSeekV2MoE module)
With EP and all2all kernels - this is no longer viable as all
GPU ranks in DP, produce the complete set of hidden_states.
Therefore it is required that we reduce the shared_experts output
early.
"""
return self.runner.must_reduce_shared_expert_outputs()
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
"""
Some combine kernels reduce across GPU ranks by default.
"""
return self.runner.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
return self.runner.forward(
hidden_states,
router_logits,
......@@ -1613,7 +1610,6 @@ class FusedMoE(PluggableLayer):
f"intermediate_size_per_partition={self.intermediate_size_per_partition}, " # noqa: E501
f"tp_size={self.tp_size},\n"
f"ep_size={self.ep_size}, "
f"reduce_results={self.reduce_results}, "
)
return s
......
......@@ -1261,7 +1261,7 @@ class FusedMoEKernelModularImpl:
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap.
......
......@@ -37,10 +37,6 @@ class DefaultMoERunner(MoERunnerBase):
for different configurations (e.g., with/without shared experts, gates, etc.).
"""
@property
def reduce_results(self) -> bool:
return self._reduce_results
@property
def do_naive_dispatch_combine(self) -> bool:
return (
......
......@@ -26,18 +26,7 @@ class MoERunner(ABC):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
@abstractmethod
def must_reduce_shared_expert_outputs(self) -> bool:
raise NotImplementedError
@abstractmethod
def maybe_all_reduce_tensor_model_parallel(
self,
final_hidden_states: torch.Tensor,
):
) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
......
......@@ -81,7 +81,9 @@ def _resolve_layer_name(layer_name: str | LayerName) -> str:
# Note: _moe_forward and _moe_forward_shared should not contain any
# implementation details, They should merely pass along control to
# the runner's 'forward_dispatch' method.
# the runner's '_forward_dispatch' method.
# These functions should never be called directly since they do not
# include all the functionality of the MoE layer.
def _moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -89,7 +91,7 @@ def _moe_forward(
layer_name: _layer_name_type,
) -> torch.Tensor:
layer = get_layer_from_name(_resolve_layer_name(layer_name))
return layer.runner.forward_dispatch(
return layer.runner._forward_dispatch(
layer,
hidden_states,
router_logits,
......@@ -113,7 +115,7 @@ def _moe_forward_shared(
layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
layer = get_layer_from_name(_resolve_layer_name(layer_name))
return layer.runner.forward_dispatch(
return layer.runner._forward_dispatch(
layer,
hidden_states,
router_logits,
......@@ -143,7 +145,7 @@ def _moe_forward_shared_fake(
direct_register_custom_op(
op_name="moe_forward",
op_func=_moe_forward,
mutates_args=["hidden_states"], # is this still true?
mutates_args=["hidden_states"],
fake_impl=_moe_forward_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
......@@ -157,6 +159,15 @@ direct_register_custom_op(
)
def _unpack(
result: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
) -> tuple[torch.Tensor | None, torch.Tensor]:
if isinstance(result, tuple):
return result
else:
return (None, result)
class MoERunnerBase(MoERunner):
"""
Abstract base class providing common functionality for MoE runner implementations.
......@@ -174,7 +185,6 @@ class MoERunnerBase(MoERunner):
allowing flexibility in the actual MoE computation implementation.
Key abstract methods that subclasses must implement:
- reduce_results: Determines whether results should be reduced across ranks
- _forward_impl: The core MoE computation logic specific to each runner type
"""
......@@ -187,17 +197,23 @@ class MoERunnerBase(MoERunner):
gate: torch.nn.Module | None,
shared_experts: torch.nn.Module | None,
quant_method: FusedMoEMethodBase,
reduce_results: bool,
enable_dbo: bool,
routed_output_transform: torch.nn.Module | None = None,
routed_scaling_factor: float = 1.0,
):
super().__init__()
self.moe_config = moe_config
self.router = router
self.routed_input_transform = routed_input_transform
self.routed_output_transform = routed_output_transform
self.routed_scaling_factor = routed_scaling_factor
self.gate = gate
self.quant_method = quant_method
self._reduce_results = reduce_results
self.enable_dbo = enable_dbo
self._fused_output_is_reduced = (
self.quant_method.moe_kernel is not None
and self.quant_method.moe_kernel.output_is_reduced()
)
self._shared_experts: SharedExperts | None = None
if shared_experts is not None:
......@@ -209,7 +225,6 @@ class MoERunnerBase(MoERunner):
# called, i.e. by a MK or by the MoERunner.
# Once the MK can be created upfront, we can just pass in the proper
# flags derived from the quant_method's MK.
reduce_results=reduce_results,
quant_method=quant_method,
enable_dbo=enable_dbo,
)
......@@ -217,7 +232,7 @@ class MoERunnerBase(MoERunner):
# Needed for string -> FusedMoE layer lookup in custom ops.
self.layer_name = layer_name
self.forward_entry = self._select_forward()
self._forward_entry = self._select_forward()
def _select_forward(self) -> Callable:
if current_platform.is_tpu() or current_platform.is_cpu():
......@@ -245,38 +260,6 @@ class MoERunnerBase(MoERunner):
def is_internal_router(self) -> bool:
return self.gate is not None
@property
@abstractmethod
def reduce_results(self) -> bool:
raise NotImplementedError
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
layer. The result of this function is typically used as
the reduce_results argument to the module.
When just tensor-parallel is used, it is not required to reduce
the shared_experts results immediately. Instead we reduce at the
once at the end of the MoE op. (Refer to DeepSeekV2MoE module)
With EP and all2all kernels - this is no longer viable as all
GPU ranks in DP, produce the complete set of hidden_states.
Therefore it is required that we reduce the shared_experts output
early.
"""
return (
self.quant_method.moe_kernel is not None
and self.quant_method.moe_kernel.output_is_reduced()
)
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
"""
Some combine kernels reduce across GPU ranks by default.
"""
if self.must_reduce_shared_expert_outputs():
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
def apply_routed_input_transform(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor | None]:
......@@ -286,10 +269,6 @@ class MoERunnerBase(MoERunner):
is saved separately so shared experts get [S, hidden_size] while
routed experts get the transformed [S, moe_latent_size].
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
moved inside SharedFusedMoE to all-reduce on the smaller latent
dimension.
Returns (possibly transformed) hidden states and the input for shared
experts (or None if there are no shared experts).
"""
......@@ -306,33 +285,79 @@ class MoERunnerBase(MoERunner):
hidden_states if self._shared_experts is not None else None,
)
def _maybe_reduce_output(
def apply_routed_output_transform(
self,
states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
trunc_sizes: list[int],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
def trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
return x[..., :trunc_size]
fused_output: torch.Tensor,
) -> torch.Tensor:
"""Apply transform to routed expert output (e.g., latent to full dim).
Used by latent MoE models (e.g., NemotronH) where routed experts
operate in a compressed latent space and need projection back to
the full hidden dimension before combining with shared expert output.
"""
if self.routed_output_transform is not None:
r = self.routed_output_transform(fused_output)
fused_output = r[0] if isinstance(r, tuple) else r
return fused_output
def reduce_and_trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
return trunc(self.maybe_all_reduce_tensor_model_parallel(x), trunc_size)
def _maybe_apply_routed_scale_to_output(
self,
shared_output: torch.Tensor | None,
fused_output: torch.Tensor,
) -> tuple[torch.Tensor | None, torch.Tensor]:
"""Apply routed_scaling_factor to the output with FP16 overflow
protection.
Scale the fused expert output by routed_scaling_factor. For FP16,
avoid overflow by dividing shared_output by the scale instead
(the decoder layer compensates with matching divisions).
"""
if self.routed_scaling_factor != 1.0:
if fused_output.dtype != torch.float16:
fused_output *= self.routed_scaling_factor
elif shared_output is not None:
shared_output *= 1.0 / self.routed_scaling_factor
return shared_output, fused_output
def _maybe_reduce_shared_expert_output(
self,
shared_output: torch.Tensor | None,
) -> torch.Tensor | None:
"""All-reduce shared expert output when the combine kernel already
reduced fused output.
This is the "early" all-reduce path. When the combine kernel produces
already-reduced fused output, shared output must be reduced separately
to match.
"""
if self._fused_output_is_reduced:
assert shared_output is not None
shared_output = tensor_model_parallel_all_reduce(shared_output)
return shared_output
def _maybe_reduce_final_output(
self,
states: torch.Tensor,
trunc_size: int,
) -> torch.Tensor:
"""Truncate padded dimensions and all-reduce the combined output.
This is the "late" all-reduce path. When neither fused nor shared
output was individually reduced, the combined sum is all-reduced
here. Skipped when sequence-parallel is active (SP handles its
own reduction) or when the early path already reduced both outputs.
"""
# We don't need to reduce the final output if:
# - We are not running with TP or DP
# - The MK already reduced the fused output itself.
if (
not self.moe_config.is_sequence_parallel
and self.reduce_results
and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1)
and not self._fused_output_is_reduced
):
func = reduce_and_trunc
else:
func = trunc
states = tensor_model_parallel_all_reduce(states)
if isinstance(states, tuple):
return tuple(
[func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
)
else:
assert len(trunc_sizes) == 1
return func(states, trunc_sizes[0])
return states[..., :trunc_size]
def _encode_layer_name(self) -> str | LayerName:
if _USE_LAYERNAME:
......@@ -349,7 +374,15 @@ class MoERunnerBase(MoERunner):
self,
shared_experts_input: torch.Tensor | None,
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, list[int]]:
) -> tuple[torch.Tensor, int]:
"""Pad hidden_states to moe_config.hidden_dim and compute the
original dimension for later truncation.
For latent MoE, the routed hidden_states may be smaller than
hidden_dim. Padding ensures uniform tensor sizes through the
fused MoE kernel. The returned trunc_size is used by
_maybe_reduce_final_output to strip the padding from the result.
"""
shared_experts_hidden_dim = (
shared_experts_input.shape[-1] if shared_experts_input is not None else 0
)
......@@ -365,10 +398,10 @@ class MoERunnerBase(MoERunner):
value=0.0,
)
if self._shared_experts is not None:
orig_hidden_dims = [shared_experts_hidden_dim, transformed_hidden_dim]
if self.routed_output_transform is not None and shared_experts_hidden_dim > 0:
orig_hidden_dims = shared_experts_hidden_dim
else:
orig_hidden_dims = [transformed_hidden_dim]
orig_hidden_dims = transformed_hidden_dim
return hidden_states, orig_hidden_dims
......@@ -388,6 +421,12 @@ class MoERunnerBase(MoERunner):
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> tuple[torch.Tensor | None, torch.Tensor]:
"""Run expert routing and the fused MoE kernel via the quant method.
Orchestrates shared expert execution (before/after), expert selection
via the router, and the actual fused MoE computation. Returns
(shared_expert_output, fused_expert_output).
"""
# Run this before quant_method to avoid inplace issues.
# TODO(bnell): probably not needed anymore since inplace is
# disabled when shared experts are present.
......@@ -428,6 +467,13 @@ class MoERunnerBase(MoERunner):
)
def _sequence_parallel_context(self):
"""Return a context manager for sequence-parallel token
redistribution.
When sequence parallelism is active, returns a context that handles
local size tracking for proper token scatter/gather. Otherwise
returns a no-op context.
"""
ctx = get_forward_context()
return (
ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
......@@ -448,14 +494,17 @@ class MoERunnerBase(MoERunner):
def _maybe_add_zero_expert_output(
self,
result: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
result: torch.Tensor,
) -> torch.Tensor:
"""Add the zero expert's contribution to the final result.
When a ZeroExpertRouter is used, it computes a bias-like output
from the "zero expert" that is added to the combined routed+shared
expert output.
"""
if isinstance(self.router, ZeroExpertRouter):
zero_expert_output = self.router.zero_expert_output
assert zero_expert_output is not None
if isinstance(result, tuple):
result = (result[0], result[1] + zero_expert_output)
else:
result = result + zero_expert_output
return result
......@@ -463,7 +512,7 @@ class MoERunnerBase(MoERunner):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
"""Invoke the fused moe layer.
Input:
......@@ -472,55 +521,88 @@ class MoERunnerBase(MoERunner):
Output:
- The new hidden_states.
or
- A tuple of (shared experts output, new hidden_states).
Calling sequence
- forward
- self.forward_entry (_moe_forward or _moe_forward_shared custom op)
- forward_dispatch
- self._forward_entry (_moe_forward or _moe_forward_shared custom op)
- _forward_dispatch
- _forward_impl
Note: The existence of _moe_forward and _moe_forward_shared custom ops are due
to the following reasons:
1. the chunking loop in ChunkingMoERunner._forward_impl cannot be compiled by
torch.compile
2. pytorch cannot handle union types in custom op signatures so _moe_forward
and _moe_forward_shared must be split.
2. pytorch cannot handle union types in custom op signatures so
_moe_forward and _moe_forward_shared must be split.
If ChunkingMoERunner._forward_impl can be implemented via torch.scan we can
potentially get rid of _moe_forward and _moe_forward_shared and collapse the
whole sequence into the 'forward' method.
"""
# Apply transform for routed experts (e.g., latent projection for latent MoE)
# Apply transform for routed experts (e.g., latent projection
# for latent MoE)
hidden_states, shared_experts_input = self.apply_routed_input_transform(
hidden_states
)
hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(
hidden_states, og_hidden_dim = self._maybe_pad_hidden_states(
shared_experts_input,
hidden_states,
)
fused_output = self.forward_entry(
result = self._forward_entry(
hidden_states,
router_logits,
shared_experts_input,
self._encode_layer_name(),
)
result = self._maybe_reduce_output(fused_output, og_hidden_dims)
#
# Note: there are two all-reduce points below. They are mutually
# exclusive, controlled by _fused_output_is_reduced
# - When True: the combine kernel already reduced fused_output,
# so we reduce shared_output here to match, then skip the
# all-reduce in _maybe_reduce_final_output.
# - When False: neither output is reduced yet, so we combine
# them first and all-reduce the sum in _maybe_reduce_final_output.
# Extract outputs from result
shared_output, fused_output = _unpack(result)
# If combine kernel already reduced fused, reduce shared to match.
# See note above re: the two all-reduce points.
shared_output = self._maybe_reduce_shared_expert_output(shared_output)
shared_output, fused_output = self._maybe_apply_routed_scale_to_output(
shared_output, fused_output
)
# Apply output transform (e.g. latent -> full dim)
fused_output = self.apply_routed_output_transform(fused_output)
if shared_output is not None:
result = shared_output + fused_output
else:
result = fused_output
result = self._maybe_reduce_final_output(result, og_hidden_dim)
return self._maybe_add_zero_expert_output(result)
def forward_dispatch(
def _forward_dispatch(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Entry point called by the custom op to run the MoE computation.
Handles pre-dispatch setup (gate application, external shared expert
triggering, quant config init) then delegates to _forward_impl within
the sequence-parallel context.
"""
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
......@@ -549,4 +631,11 @@ class MoERunnerBase(MoERunner):
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Core MoE computation to be implemented by subclasses.
Performs expert routing, fused MoE kernel execution, and shared
expert computation. Returns a single tensor (fused output only)
or a tuple of (shared_output, fused_output) when shared experts
are present.
"""
raise NotImplementedError
......@@ -29,8 +29,9 @@ def create_moe_runner(
gate: torch.nn.Module | None,
shared_experts: SharedExperts | None,
quant_method: FusedMoEMethodBase,
reduce_results: bool,
enable_dbo: bool,
routed_output_transform: torch.nn.Module | None = None,
routed_scaling_factor: float = 1.0,
) -> MoERunner:
return DefaultMoERunner(
layer_name,
......@@ -40,6 +41,7 @@ def create_moe_runner(
gate,
shared_experts,
quant_method,
reduce_results,
enable_dbo,
routed_output_transform=routed_output_transform,
routed_scaling_factor=routed_scaling_factor,
)
......@@ -5,10 +5,6 @@ from enum import IntEnum
import torch
import vllm.envs as envs
from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
......@@ -48,7 +44,6 @@ class SharedExperts:
layer: torch.nn.Module,
moe_config: FusedMoEConfig,
quant_method: QuantizeMethodBase,
reduce_results: bool,
enable_dbo: bool,
):
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
......@@ -68,7 +63,6 @@ class SharedExperts:
self._layer = layer
self._moe_config = moe_config
self._quant_method = quant_method
self._reduce_results = reduce_results
# Allow disabling of the separate shared experts stream for
# debug purposes.
......@@ -139,18 +133,6 @@ class SharedExperts:
return output
def _maybe_reduce_shared_out(self, shared_out: torch.Tensor) -> torch.Tensor:
# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
if (
self._reduce_results
and self._quant_method.moe_kernel is not None
and self._quant_method.moe_kernel.output_is_reduced()
and get_tensor_model_parallel_world_size() > 1
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out
@property
def _output_idx(self) -> int:
return dbo_current_ubatch_id() if self.enable_dbo else 0
......
......@@ -18,12 +18,8 @@ class SharedFusedMoE(FusedMoE):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result = super().forward(
) -> torch.Tensor:
return super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
)
if self.shared_experts is None:
return None, result
else:
return result
......@@ -100,7 +100,7 @@ class AXK1MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.routed_scaling_factor = config.routed_scaling_factor
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
self.ep_group = get_ep_group().device_group
self.ep_rank = get_ep_group().rank_in_group
......@@ -170,7 +170,6 @@ class AXK1MoE(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
......@@ -180,9 +179,8 @@ class AXK1MoE(nn.Module):
scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
# aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0
if not self.is_rocm_aiter_moe_enabled
else self.routed_scaling_factor,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scale_to_output=not self.is_rocm_aiter_moe_enabled,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
......@@ -204,43 +202,20 @@ class AXK1MoE(nn.Module):
hidden_states = sequence_parallel_chunk(hidden_states)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=hidden_states
)
else:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None:
assert shared_output is None
# Fix FP16 overflow
# See AXK1DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
if not self.is_rocm_aiter_moe_enabled:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim)
......
......@@ -131,7 +131,6 @@ class AfmoeMoE(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.route_norm if self.score_func == "sigmoid" else False,
quant_config=quant_config,
use_grouped_topk=True,
......@@ -152,20 +151,10 @@ class AfmoeMoE(nn.Module):
router_logits = self.gate(hidden_states.to(dtype=torch.float32))
fused_moe_out = self.experts(
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if self.shared_experts is not None:
shared_output, final_hidden_states = fused_moe_out
final_hidden_states = final_hidden_states + shared_output
else:
final_hidden_states = fused_moe_out
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim)
......
......@@ -283,7 +283,6 @@ class AriaTextMoELayer(nn.Module):
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
reduce_results=True,
prefix=f"{prefix}.experts",
)
......@@ -301,12 +300,7 @@ class AriaTextMoELayer(nn.Module):
router_output = torch.nn.functional.linear(hidden_states, self.router_weight)
sparse_expert_output = self.experts(hidden_states, router_output)
if self.shared_experts is not None:
return sparse_expert_output[0] + sparse_expert_output[1]
else:
return sparse_expert_output
return self.experts(hidden_states, router_output)
class AriaTextDecoderLayer(LlamaDecoderLayer):
......
......@@ -291,7 +291,6 @@ class BailingMoE(nn.Module):
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......@@ -301,6 +300,7 @@ class BailingMoE(nn.Module):
topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk,
router_logits_dtype=self.router_dtype,
routed_scaling_factor=self.routed_scaling_factor,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......@@ -314,21 +314,6 @@ class BailingMoE(nn.Module):
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if self.shared_experts is not None:
shared_output, final_hidden_states = final_hidden_states
else:
shared_output = None
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_size)
......
......@@ -358,7 +358,6 @@ class BailingMoeV25(nn.Module):
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
......@@ -368,6 +367,8 @@ class BailingMoeV25(nn.Module):
topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk,
router_logits_dtype=self.router_dtype,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scale_to_output=True,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......@@ -383,22 +384,6 @@ class BailingMoeV25(nn.Module):
hidden_states=hidden_states, router_logits=router_logits
)
# Handle tuple return from SharedFusedMoE
if self.shared_experts is not None:
shared_output, final_hidden_states = final_hidden_states
else:
shared_output = None
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_size)
......
......@@ -85,7 +85,6 @@ class DbrxExperts(FusedMoE):
hidden_size=config.d_model,
intermediate_size=config.ffn_config.ffn_hidden_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=get_tensor_model_parallel_world_size(),
......
......@@ -318,7 +318,6 @@ class DeepseekV2MoE(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
......@@ -326,11 +325,9 @@ class DeepseekV2MoE(nn.Module):
topk_group=getattr(config, "topk_group", 1),
prefix=f"{prefix}.experts",
scoring_func=getattr(config, "scoring_func", "softmax"),
# we do scaling outside, set factor to 1.0 to avoid double mul
# aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0
if not self.is_rocm_aiter_moe_enabled
else self.routed_scaling_factor,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scale_to_output=not self.is_rocm_aiter_moe_enabled,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
......@@ -363,43 +360,20 @@ class DeepseekV2MoE(nn.Module):
hidden_states = sequence_parallel_chunk(hidden_states)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=hidden_states
)
else:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None:
assert shared_output is None
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
if not self.is_rocm_aiter_moe_enabled:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim)
......
......@@ -37,7 +37,6 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
......@@ -120,7 +119,6 @@ class Dots1MoE(nn.Module):
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
......@@ -163,7 +161,6 @@ class Dots1MoE(nn.Module):
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
......@@ -171,9 +168,9 @@ class Dots1MoE(nn.Module):
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scale_to_output=True,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......@@ -182,16 +179,9 @@ class Dots1MoE(nn.Module):
router_logits, _ = self.gate(hidden_states)
shared_out, routed_out = self.experts(
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if self.shared_experts is not None:
final_hidden_states = (routed_out + shared_out) * self.routed_scaling_factor
else:
final_hidden_states = routed_out * self.routed_scaling_factor
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment