Unverified Commit 726efe17 authored by bnellnm, committed by GitHub
Browse files

[MoE Refactor] Move the shared/fused expert output sum into MoERunnerBase (#35949)


Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 59556265
...@@ -236,7 +236,6 @@ class MoETestConfig: ...@@ -236,7 +236,6 @@ class MoETestConfig:
use_gate: bool use_gate: bool
use_routed_input_transform: bool use_routed_input_transform: bool
enable_eplb: bool = False enable_eplb: bool = False
reduce_results: bool = False
backend: str | None = None backend: str | None = None
ep_size: int = 1 ep_size: int = 1
dp_size: int = 1 dp_size: int = 1
...@@ -295,7 +294,6 @@ def generate_valid_test_configs( ...@@ -295,7 +294,6 @@ def generate_valid_test_configs(
use_shared_experts, use_shared_experts,
use_gate, use_gate,
use_routed_input_transform, use_routed_input_transform,
reduce_results,
) in product( ) in product(
SHAPE_COMBOS, SHAPE_COMBOS,
NUM_EXPERTS, NUM_EXPERTS,
...@@ -304,7 +302,6 @@ def generate_valid_test_configs( ...@@ -304,7 +302,6 @@ def generate_valid_test_configs(
[False, True], # shared [False, True], # shared
[False, True], # gate [False, True], # gate
[False, True], # routed input exform [False, True], # routed input exform
[False, True], # reduce results
): ):
config = MoETestConfig( config = MoETestConfig(
shape[0], # m shape[0], # m
...@@ -318,7 +315,6 @@ def generate_valid_test_configs( ...@@ -318,7 +315,6 @@ def generate_valid_test_configs(
use_gate, use_gate,
use_routed_input_transform, use_routed_input_transform,
enable_eplb, enable_eplb,
reduce_results,
backend, backend,
ep_size, ep_size,
dp_size, dp_size,
...@@ -395,18 +391,7 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]: ...@@ -395,18 +391,7 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]:
and config.backend.startswith("flashinfer_nvlink") and config.backend.startswith("flashinfer_nvlink")
and not current_platform.has_device_capability(90) and not current_platform.has_device_capability(90)
): ):
return False, "flashinfer_nvlink needs an H100+ GPUs" return False, "flashinfer_nvlink needs H100+ GPUs"
# reduce_results incompatibilities
if config.reduce_results and config.use_shared_experts:
return False, "reduce_results=True is not compatible with shared_experts=True"
if config.reduce_results and config.quantization is not None:
return (
False,
"reduce_results=True only tested with unquantized data types in "
"order to limit number of tests run",
)
# Backend-specific checks # Backend-specific checks
if config.backend is not None: if config.backend is not None:
...@@ -448,10 +433,6 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]: ...@@ -448,10 +433,6 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]:
if config.enable_eplb and config.backend not in EPLB_SUPPORTED_BACKENDS: if config.enable_eplb and config.backend not in EPLB_SUPPORTED_BACKENDS:
return False, f"EPLB not supported with {config.backend}." return False, f"EPLB not supported with {config.backend}."
world_size = config.tp_size * config.dp_size
if config.reduce_results and world_size == 1:
return False, "reduce_results=True only makes sense for multi-GPU tests"
if ( if (
config.backend is not None config.backend is not None
and config.backend.startswith("flashinfer_nvlink") and config.backend.startswith("flashinfer_nvlink")
...@@ -846,7 +827,6 @@ def make_fused_moe_layer( ...@@ -846,7 +827,6 @@ def make_fused_moe_layer(
tp_size: int, tp_size: int,
ep_size: int, ep_size: int,
dp_size: int, dp_size: int,
reduce_results: bool,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
top_k: int, top_k: int,
...@@ -874,7 +854,7 @@ def make_fused_moe_layer( ...@@ -874,7 +854,7 @@ def make_fused_moe_layer(
routed_input_transform: torch.nn.Module | None = None, routed_input_transform: torch.nn.Module | None = None,
routed_output_transform: torch.nn.Module | None = None, routed_output_transform: torch.nn.Module | None = None,
pcp_size: int | None = 1, pcp_size: int | None = 1,
) -> tuple[Callable, FusedMoE]: ) -> FusedMoE:
quant_config, qw = make_quant_config(quantization, w1, w2, global_num_experts) quant_config, qw = make_quant_config(quantization, w1, w2, global_num_experts)
kwargs = dict() kwargs = dict()
...@@ -887,8 +867,10 @@ def make_fused_moe_layer( ...@@ -887,8 +867,10 @@ def make_fused_moe_layer(
# Add gate and routed_input_transform if provided # Add gate and routed_input_transform if provided
if gate is not None: if gate is not None:
kwargs["gate"] = gate kwargs["gate"] = gate
if routed_input_transform is not None: if routed_input_transform is not None:
kwargs["routed_input_transform"] = routed_input_transform kwargs["routed_input_transform"] = routed_input_transform
kwargs["routed_output_transform"] = routed_output_transform
layer = builder( layer = builder(
num_experts=global_num_experts, num_experts=global_num_experts,
...@@ -896,7 +878,6 @@ def make_fused_moe_layer( ...@@ -896,7 +878,6 @@ def make_fused_moe_layer(
hidden_size=hidden_size, hidden_size=hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
params_dtype=in_dtype, params_dtype=in_dtype,
reduce_results=reduce_results,
renormalize=renormalize, renormalize=renormalize,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
...@@ -936,36 +917,7 @@ def make_fused_moe_layer( ...@@ -936,36 +917,7 @@ def make_fused_moe_layer(
layer.quant_method.process_weights_after_loading(layer) layer.quant_method.process_weights_after_loading(layer)
def _moe( return layer
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
if shared_experts is None:
final_shared_states = None
final_hidden_states = layer(hidden_states, router_logits)
else:
final_shared_states, final_hidden_states = layer(
hidden_states, router_logits
)
# Apply routed output transform if provided
# (e.g., latent space -> original space)
if routed_output_transform is not None:
final_hidden_states = routed_output_transform(final_hidden_states)
if shared_experts is not None:
assert not reduce_results
assert final_shared_states is not None
final_hidden_states += final_shared_states
if not reduce_results and layer.tp_size > 1:
final_hidden_states = layer.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states
return _moe, layer
def make_fake_moe_layer( def make_fake_moe_layer(
...@@ -999,7 +951,6 @@ def make_fake_moe_layer( ...@@ -999,7 +951,6 @@ def make_fake_moe_layer(
tp_size: int = 1, tp_size: int = 1,
dp_size: int = 1, dp_size: int = 1,
ep_size: int = 1, ep_size: int = 1,
reduce_results: bool = False,
) -> Callable: ) -> Callable:
activation = MoEActivation.from_str(activation) activation = MoEActivation.from_str(activation)
...@@ -1101,7 +1052,7 @@ def make_fake_moe_layer( ...@@ -1101,7 +1052,7 @@ def make_fake_moe_layer(
def _test_body_regular( def _test_body_regular(
moe_fn: Callable, moe_layer: Callable,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
vllm_config: VllmConfig, vllm_config: VllmConfig,
...@@ -1118,13 +1069,12 @@ def _test_body_regular( ...@@ -1118,13 +1069,12 @@ def _test_body_regular(
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
): ):
output = moe_fn(hidden_states, router_logits) output = moe_layer(hidden_states, router_logits)
return baseline_output, output return baseline_output, output
def _test_body_eplb( def _test_body_eplb(
moe_fn: Callable,
moe_layer: FusedMoE, moe_layer: FusedMoE,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -1145,7 +1095,6 @@ def _test_body_eplb( ...@@ -1145,7 +1095,6 @@ def _test_body_eplb(
n: int, n: int,
top_k: int, top_k: int,
shared_experts, shared_experts,
reduce_results: bool,
gate: torch.nn.Module | None, gate: torch.nn.Module | None,
routed_input_transform: torch.nn.Module | None, routed_input_transform: torch.nn.Module | None,
routed_output_transform: torch.nn.Module | None, routed_output_transform: torch.nn.Module | None,
...@@ -1161,7 +1110,7 @@ def _test_body_eplb( ...@@ -1161,7 +1110,7 @@ def _test_body_eplb(
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
): ):
output_before = moe_fn(hidden_states, router_logits) output_before = moe_layer(hidden_states, router_logits)
# Create a fresh FusedMoE layer with enable_eplb=True # Create a fresh FusedMoE layer with enable_eplb=True
# Delete the original layer's registration so the constructor can # Delete the original layer's registration so the constructor can
...@@ -1174,7 +1123,7 @@ def _test_body_eplb( ...@@ -1174,7 +1123,7 @@ def _test_body_eplb(
# When using routed_input_transform, experts operate in latent space # When using routed_input_transform, experts operate in latent space
hidden_size_for_layer = k // 2 if routed_input_transform is not None else k hidden_size_for_layer = k // 2 if routed_input_transform is not None else k
moe_fn, moe_layer = make_fused_moe_layer( eplb_moe_layer = make_fused_moe_layer(
quantization=quantization, quantization=quantization,
use_ep=use_ep, use_ep=use_ep,
hidden_size=hidden_size_for_layer, hidden_size=hidden_size_for_layer,
...@@ -1183,7 +1132,6 @@ def _test_body_eplb( ...@@ -1183,7 +1132,6 @@ def _test_body_eplb(
tp_size=tp_size, tp_size=tp_size,
ep_size=ep_size, ep_size=ep_size,
dp_size=dp_size, dp_size=dp_size,
reduce_results=reduce_results,
w1=w1, w1=w1,
w2=w2, w2=w2,
top_k=top_k, top_k=top_k,
...@@ -1196,14 +1144,14 @@ def _test_body_eplb( ...@@ -1196,14 +1144,14 @@ def _test_body_eplb(
) )
# Necessary? # Necessary?
if moe_layer._expert_map is not None: if eplb_moe_layer._expert_map is not None:
moe_layer._expert_map = moe_layer._expert_map.to(device) eplb_moe_layer._expert_map = eplb_moe_layer._expert_map.to(device)
# All ranks must generate the same permutation # All ranks must generate the same permutation
initial_indices = torch.arange(num_experts, dtype=torch.long) initial_indices = torch.arange(num_experts, dtype=torch.long)
shuffled_indices = initial_indices[torch.randperm(num_experts)] shuffled_indices = initial_indices[torch.randperm(num_experts)]
expert_weights = [list(moe_layer.get_expert_weights())] expert_weights = [list(eplb_moe_layer.get_expert_weights())]
communicator = create_eplb_communicator( communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(), group_coordinator=get_eplb_group(),
...@@ -1227,7 +1175,7 @@ def _test_body_eplb( ...@@ -1227,7 +1175,7 @@ def _test_body_eplb(
num_experts, dtype=torch.int32, device=device num_experts, dtype=torch.int32, device=device
) )
moe_layer.set_eplb_state( eplb_moe_layer.set_eplb_state(
moe_layer_idx=0, moe_layer_idx=0,
expert_load_view=torch.zeros( expert_load_view=torch.zeros(
(1, num_experts), (1, num_experts),
...@@ -1244,7 +1192,7 @@ def _test_body_eplb( ...@@ -1244,7 +1192,7 @@ def _test_body_eplb(
), ),
) )
moe_layer.eplb_state.should_record_tensor = torch.ones( eplb_moe_layer.eplb_state.should_record_tensor = torch.ones(
(), dtype=torch.bool, device=device (), dtype=torch.bool, device=device
) )
...@@ -1255,7 +1203,7 @@ def _test_body_eplb( ...@@ -1255,7 +1203,7 @@ def _test_body_eplb(
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
): ):
output_after = moe_fn(hidden_states, router_logits) output_after = eplb_moe_layer(hidden_states, router_logits)
return output_before, output_after return output_before, output_after
...@@ -1274,7 +1222,6 @@ def _run_one_config( ...@@ -1274,7 +1222,6 @@ def _run_one_config(
num_experts: int, num_experts: int,
top_k: int, top_k: int,
quantization: str | None, quantization: str | None,
reduce_results: bool,
backend: str | None, backend: str | None,
test_body_fn: Callable, test_body_fn: Callable,
use_shared_experts: bool, use_shared_experts: bool,
...@@ -1341,7 +1288,6 @@ def _run_one_config( ...@@ -1341,7 +1288,6 @@ def _run_one_config(
tp_size=tp_size, tp_size=tp_size,
ep_size=ep_size, ep_size=ep_size,
dp_size=dp_size, dp_size=dp_size,
reduce_results=reduce_results,
) )
baseline_output = baseline_layer(hidden_states, router_logits) baseline_output = baseline_layer(hidden_states, router_logits)
...@@ -1369,7 +1315,7 @@ def _run_one_config( ...@@ -1369,7 +1315,7 @@ def _run_one_config(
hidden_size_for_layer = k // 2 if routed_input_transform is not None else k hidden_size_for_layer = k // 2 if routed_input_transform is not None else k
# Create initial MoE layer # Create initial MoE layer
moe_fn, moe_layer = make_fused_moe_layer( moe_layer = make_fused_moe_layer(
quantization=quantization, quantization=quantization,
use_ep=use_ep, use_ep=use_ep,
hidden_size=hidden_size_for_layer, hidden_size=hidden_size_for_layer,
...@@ -1378,7 +1324,6 @@ def _run_one_config( ...@@ -1378,7 +1324,6 @@ def _run_one_config(
tp_size=tp_size, tp_size=tp_size,
ep_size=ep_size, ep_size=ep_size,
dp_size=dp_size, dp_size=dp_size,
reduce_results=reduce_results,
w1=w1, w1=w1,
w2=w2, w2=w2,
top_k=top_k, top_k=top_k,
...@@ -1402,7 +1347,6 @@ def _run_one_config( ...@@ -1402,7 +1347,6 @@ def _run_one_config(
# Call the test body function with all necessary context # Call the test body function with all necessary context
expected, actual = test_body_fn( expected, actual = test_body_fn(
moe_fn=moe_fn,
moe_layer=moe_layer, moe_layer=moe_layer,
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
...@@ -1423,7 +1367,6 @@ def _run_one_config( ...@@ -1423,7 +1367,6 @@ def _run_one_config(
m=m, m=m,
top_k=top_k, top_k=top_k,
shared_experts=shared_experts, shared_experts=shared_experts,
reduce_results=reduce_results,
gate=gate, gate=gate,
routed_input_transform=routed_input_transform, routed_input_transform=routed_input_transform,
routed_output_transform=routed_output_transform, routed_output_transform=routed_output_transform,
...@@ -1520,7 +1463,6 @@ def test_moe_layer_no_parallel( ...@@ -1520,7 +1463,6 @@ def test_moe_layer_no_parallel(
test_config.num_experts, test_config.num_experts,
test_config.top_k, test_config.top_k,
test_config.quantization, test_config.quantization,
test_config.reduce_results,
test_config.backend, test_config.backend,
_test_body_regular, _test_body_regular,
use_shared_experts=test_config.use_shared_experts, use_shared_experts=test_config.use_shared_experts,
...@@ -1578,7 +1520,6 @@ def _parallel_worker( ...@@ -1578,7 +1520,6 @@ def _parallel_worker(
test_config.num_experts, test_config.num_experts,
test_config.top_k, test_config.top_k,
test_config.quantization, test_config.quantization,
test_config.reduce_results,
test_config.backend, test_config.backend,
functools.partial( functools.partial(
_test_body_config, test_config=test_config, cpu_group=cpu_group _test_body_config, test_config=test_config, cpu_group=cpu_group
...@@ -1597,7 +1538,7 @@ def _parallel_worker( ...@@ -1597,7 +1538,7 @@ def _parallel_worker(
failed = failed + 1 failed = failed + 1
if verbosity > 0: if verbosity > 0:
traceback.print_exc() traceback.print_exc()
print(f"\n{str(ex)}\nFAILED {ex.__class__}") print(f"\n{str(ex)}\nFAILED")
else: else:
print("F", end="") print("F", end="")
finally: finally:
......
...@@ -165,7 +165,6 @@ def test_routed_input_transform_inside_vs_outside( ...@@ -165,7 +165,6 @@ def test_routed_input_transform_inside_vs_outside(
top_k=top_k, top_k=top_k,
hidden_size=latent_size, hidden_size=latent_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
reduce_results=False,
renormalize=True, renormalize=True,
params_dtype=dtype, params_dtype=dtype,
tp_size=1, tp_size=1,
...@@ -183,7 +182,6 @@ def test_routed_input_transform_inside_vs_outside( ...@@ -183,7 +182,6 @@ def test_routed_input_transform_inside_vs_outside(
top_k=top_k, top_k=top_k,
hidden_size=latent_size, hidden_size=latent_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
reduce_results=False,
renormalize=True, renormalize=True,
params_dtype=dtype, params_dtype=dtype,
tp_size=1, tp_size=1,
...@@ -212,34 +210,20 @@ def test_routed_input_transform_inside_vs_outside( ...@@ -212,34 +210,20 @@ def test_routed_input_transform_inside_vs_outside(
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
router_logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype) router_logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
# Clone inputs so any in-place modification by Method A
# cannot affect Method B's computation.
hidden_states_A = hidden_states.clone()
router_logits_A = router_logits.clone()
with set_forward_context(None, vllm_config, num_tokens=num_tokens): with set_forward_context(None, vllm_config, num_tokens=num_tokens):
shared_out_A, routed_out_A = moe_with_transform( # Method A: combined output (shared + routed)
hidden_states_A, router_logits_A combined_A = moe_with_transform(hidden_states, router_logits)
)
# Method B: manually transform, get routed output, add shared
transformed_hidden = routed_transform(hidden_states) transformed_hidden = routed_transform(hidden_states)
shared_out_B, routed_out_B = moe_without_transform( routed_out_B = moe_without_transform(transformed_hidden, router_logits)
transformed_hidden, router_logits shared_out_B = shared_experts(hidden_states)
) combined_B = shared_out_B + routed_out_B
expected_shared_out = shared_experts(hidden_states)
_assert_close( torch.testing.assert_close(
routed_out_A, combined_A,
routed_out_B, combined_B,
atol=1e-3,
rtol=1e-3,
label="Routed output: transform inside vs outside",
)
_assert_close(
shared_out_A,
expected_shared_out,
atol=1e-3, atol=1e-3,
rtol=1e-3, rtol=1e-3,
label="Shared expert output", msg="Combined output should match: transform inside vs outside",
) )
...@@ -592,9 +592,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -592,9 +592,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs) return self.base_layer.forward(*args, **kwargs)
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
@property @property
def quant_method(self): def quant_method(self):
return self.base_layer.quant_method return self.base_layer.quant_method
......
...@@ -716,7 +716,7 @@ class MarlinExperts(MarlinExpertsBase): ...@@ -716,7 +716,7 @@ class MarlinExperts(MarlinExpertsBase):
): ):
assert self.w1_scale is not None assert self.w1_scale is not None
assert self.w2_scale is not None assert self.w2_scale is not None
return fused_marlin_moe( fused_marlin_moe(
hidden_states=hidden_states, hidden_states=hidden_states,
w1=w1, w1=w1,
w2=w2, w2=w2,
......
...@@ -230,11 +230,18 @@ class FusedMoE(PluggableLayer): ...@@ -230,11 +230,18 @@ class FusedMoE(PluggableLayer):
hidden_size: Input hidden state size of the transformer hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
reduce_results: Whether to all_reduce on the output of the layer
renormalize: Whether to renormalize the logits in the fused_moe kernel renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure. quant_config: Quantization configure.
enable_eplb: Whether to enable expert parallelism load balancer. enable_eplb: Whether to enable expert parallelism load balancer.
router_logits_dtype: Data type for router logits buffers. router_logits_dtype: Data type for router logits buffers.
routed_scaling_factor: A scaling factor that is applied to the topk_weights
by the router or the output of the layer depending
on the value of `apply_routed_scale_to_output`
apply_routed_scale_to_output: Determine whether or not `routed_scaling_factor`
is applied to the topk_weights or to the experts
output. It is applied to the experts output
instead of the topk_weights when this feature is
not supported by the router (or the experts).
""" """
# --8<-- [end:fused_moe] # --8<-- [end:fused_moe]
...@@ -246,7 +253,6 @@ class FusedMoE(PluggableLayer): ...@@ -246,7 +253,6 @@ class FusedMoE(PluggableLayer):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
params_dtype: torch.dtype | None = None, params_dtype: torch.dtype | None = None,
reduce_results: bool = False,
renormalize: bool = True, renormalize: bool = True,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: int | None = None, num_expert_group: int | None = None,
...@@ -274,12 +280,12 @@ class FusedMoE(PluggableLayer): ...@@ -274,12 +280,12 @@ class FusedMoE(PluggableLayer):
gate: torch.nn.Module | None = None, gate: torch.nn.Module | None = None,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
routed_input_transform: torch.nn.Module | None = None, routed_input_transform: torch.nn.Module | None = None,
routed_output_transform: torch.nn.Module | None = None,
apply_routed_scale_to_output: bool = False,
zero_expert_type: str | None = None, zero_expert_type: str | None = None,
): ):
super().__init__() super().__init__()
self._routed_input_transform = routed_input_transform
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
...@@ -425,7 +431,6 @@ class FusedMoE(PluggableLayer): ...@@ -425,7 +431,6 @@ class FusedMoE(PluggableLayer):
assert intermediate_size % self.tp_size == 0 assert intermediate_size % self.tp_size == 0
intermediate_size_per_partition = intermediate_size // self.tp_size intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize self.renormalize = renormalize
# TODO(bnell): these attributes are only used by monolithic kernels. # TODO(bnell): these attributes are only used by monolithic kernels.
...@@ -437,7 +442,14 @@ class FusedMoE(PluggableLayer): ...@@ -437,7 +442,14 @@ class FusedMoE(PluggableLayer):
self.topk_group = topk_group self.topk_group = topk_group
self.custom_routing_function = custom_routing_function self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor # When apply_routed_scale_to_output is True, we set the scaling factor
# to 1.0 so it ends up being a nop. Applying the scale will be handled
# by the runner in this case.
# The member variable must be set in the same way as the router since
# some quantization methods can access it.
self.routed_scaling_factor = (
routed_scaling_factor if not apply_routed_scale_to_output else 1.0
)
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
# TODO(bnell): end attributes # TODO(bnell): end attributes
...@@ -456,7 +468,7 @@ class FusedMoE(PluggableLayer): ...@@ -456,7 +468,7 @@ class FusedMoE(PluggableLayer):
topk_group=topk_group, topk_group=topk_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=self.num_fused_shared_experts, num_fused_shared_experts=self.num_fused_shared_experts,
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
...@@ -578,12 +590,18 @@ class FusedMoE(PluggableLayer): ...@@ -578,12 +590,18 @@ class FusedMoE(PluggableLayer):
layer_name=self.layer_name, layer_name=self.layer_name,
moe_config=self.moe_config, moe_config=self.moe_config,
router=self.router, router=self.router,
routed_input_transform=self._routed_input_transform,
gate=gate, gate=gate,
shared_experts=shared_experts, shared_experts=shared_experts,
quant_method=self.quant_method, quant_method=self.quant_method,
reduce_results=self.reduce_results,
enable_dbo=self.vllm_config.parallel_config.enable_dbo, enable_dbo=self.vllm_config.parallel_config.enable_dbo,
routed_input_transform=routed_input_transform,
routed_output_transform=routed_output_transform,
# When apply_routed_scale_to_output is True, we allow
# the scaling factor to be passed to the runner, otherwise
# we pass 1.0 so it ends up being a nop.
routed_scaling_factor=routed_scaling_factor
if apply_routed_scale_to_output
else 1.0,
) )
# TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py # TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py
...@@ -1514,32 +1532,11 @@ class FusedMoE(PluggableLayer): ...@@ -1514,32 +1532,11 @@ class FusedMoE(PluggableLayer):
self.ensure_moe_quant_config_init() self.ensure_moe_quant_config_init()
return self.quant_method.moe_quant_config return self.quant_method.moe_quant_config
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
layer. The result of this function is typically used as
the reduce_results argument to the module.
When just tensor-parallel is used, it is not required to reduce
the shared_experts results immediately. Instead we reduce at the
once at the end of the MoE op. (Refer to DeepSeekV2MoE module)
With EP and all2all kernels - this is no longer viable as all
GPU ranks in DP, produce the complete set of hidden_states.
Therefore it is required that we reduce the shared_experts output
early.
"""
return self.runner.must_reduce_shared_expert_outputs()
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
"""
Some combine kernels reduce across GPU ranks by default.
"""
return self.runner.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
return self.runner.forward( return self.runner.forward(
hidden_states, hidden_states,
router_logits, router_logits,
...@@ -1613,7 +1610,6 @@ class FusedMoE(PluggableLayer): ...@@ -1613,7 +1610,6 @@ class FusedMoE(PluggableLayer):
f"intermediate_size_per_partition={self.intermediate_size_per_partition}, " # noqa: E501 f"intermediate_size_per_partition={self.intermediate_size_per_partition}, " # noqa: E501
f"tp_size={self.tp_size},\n" f"tp_size={self.tp_size},\n"
f"ep_size={self.ep_size}, " f"ep_size={self.ep_size}, "
f"reduce_results={self.reduce_results}, "
) )
return s return s
......
...@@ -1261,7 +1261,7 @@ class FusedMoEKernelModularImpl: ...@@ -1261,7 +1261,7 @@ class FusedMoEKernelModularImpl:
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
""" """
The _finalize method is a wrapper around self.prepare_finalize.finalize The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap. that handles DBO, async and shared expert overlap.
......
...@@ -37,10 +37,6 @@ class DefaultMoERunner(MoERunnerBase): ...@@ -37,10 +37,6 @@ class DefaultMoERunner(MoERunnerBase):
for different configurations (e.g., with/without shared experts, gates, etc.). for different configurations (e.g., with/without shared experts, gates, etc.).
""" """
@property
def reduce_results(self) -> bool:
return self._reduce_results
@property @property
def do_naive_dispatch_combine(self) -> bool: def do_naive_dispatch_combine(self) -> bool:
return ( return (
......
...@@ -26,18 +26,7 @@ class MoERunner(ABC): ...@@ -26,18 +26,7 @@ class MoERunner(ABC):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def must_reduce_shared_expert_outputs(self) -> bool:
raise NotImplementedError
@abstractmethod
def maybe_all_reduce_tensor_model_parallel(
self,
final_hidden_states: torch.Tensor,
):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
......
...@@ -81,7 +81,9 @@ def _resolve_layer_name(layer_name: str | LayerName) -> str: ...@@ -81,7 +81,9 @@ def _resolve_layer_name(layer_name: str | LayerName) -> str:
# Note: _moe_forward and _moe_forward_shared should not contain any # Note: _moe_forward and _moe_forward_shared should not contain any
# implementation details, They should merely pass along control to # implementation details, They should merely pass along control to
# the runner's 'forward_dispatch' method. # the runner's '_forward_dispatch' method.
# These functions should never be called directly since they do not
# include all the functionality of the MoE layer.
def _moe_forward( def _moe_forward(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -89,7 +91,7 @@ def _moe_forward( ...@@ -89,7 +91,7 @@ def _moe_forward(
layer_name: _layer_name_type, layer_name: _layer_name_type,
) -> torch.Tensor: ) -> torch.Tensor:
layer = get_layer_from_name(_resolve_layer_name(layer_name)) layer = get_layer_from_name(_resolve_layer_name(layer_name))
return layer.runner.forward_dispatch( return layer.runner._forward_dispatch(
layer, layer,
hidden_states, hidden_states,
router_logits, router_logits,
...@@ -113,7 +115,7 @@ def _moe_forward_shared( ...@@ -113,7 +115,7 @@ def _moe_forward_shared(
layer_name: _layer_name_type, layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
layer = get_layer_from_name(_resolve_layer_name(layer_name)) layer = get_layer_from_name(_resolve_layer_name(layer_name))
return layer.runner.forward_dispatch( return layer.runner._forward_dispatch(
layer, layer,
hidden_states, hidden_states,
router_logits, router_logits,
...@@ -143,7 +145,7 @@ def _moe_forward_shared_fake( ...@@ -143,7 +145,7 @@ def _moe_forward_shared_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="moe_forward", op_name="moe_forward",
op_func=_moe_forward, op_func=_moe_forward,
mutates_args=["hidden_states"], # is this still true? mutates_args=["hidden_states"],
fake_impl=_moe_forward_fake, fake_impl=_moe_forward_fake,
tags=(torch.Tag.needs_fixed_stride_order,), tags=(torch.Tag.needs_fixed_stride_order,),
) )
...@@ -157,6 +159,15 @@ direct_register_custom_op( ...@@ -157,6 +159,15 @@ direct_register_custom_op(
) )
def _unpack(
result: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
) -> tuple[torch.Tensor | None, torch.Tensor]:
if isinstance(result, tuple):
return result
else:
return (None, result)
class MoERunnerBase(MoERunner): class MoERunnerBase(MoERunner):
""" """
Abstract base class providing common functionality for MoE runner implementations. Abstract base class providing common functionality for MoE runner implementations.
...@@ -174,7 +185,6 @@ class MoERunnerBase(MoERunner): ...@@ -174,7 +185,6 @@ class MoERunnerBase(MoERunner):
allowing flexibility in the actual MoE computation implementation. allowing flexibility in the actual MoE computation implementation.
Key abstract methods that subclasses must implement: Key abstract methods that subclasses must implement:
- reduce_results: Determines whether results should be reduced across ranks
- _forward_impl: The core MoE computation logic specific to each runner type - _forward_impl: The core MoE computation logic specific to each runner type
""" """
...@@ -187,17 +197,23 @@ class MoERunnerBase(MoERunner): ...@@ -187,17 +197,23 @@ class MoERunnerBase(MoERunner):
gate: torch.nn.Module | None, gate: torch.nn.Module | None,
shared_experts: torch.nn.Module | None, shared_experts: torch.nn.Module | None,
quant_method: FusedMoEMethodBase, quant_method: FusedMoEMethodBase,
reduce_results: bool,
enable_dbo: bool, enable_dbo: bool,
routed_output_transform: torch.nn.Module | None = None,
routed_scaling_factor: float = 1.0,
): ):
super().__init__() super().__init__()
self.moe_config = moe_config self.moe_config = moe_config
self.router = router self.router = router
self.routed_input_transform = routed_input_transform self.routed_input_transform = routed_input_transform
self.routed_output_transform = routed_output_transform
self.routed_scaling_factor = routed_scaling_factor
self.gate = gate self.gate = gate
self.quant_method = quant_method self.quant_method = quant_method
self._reduce_results = reduce_results
self.enable_dbo = enable_dbo self.enable_dbo = enable_dbo
self._fused_output_is_reduced = (
self.quant_method.moe_kernel is not None
and self.quant_method.moe_kernel.output_is_reduced()
)
self._shared_experts: SharedExperts | None = None self._shared_experts: SharedExperts | None = None
if shared_experts is not None: if shared_experts is not None:
...@@ -209,7 +225,6 @@ class MoERunnerBase(MoERunner): ...@@ -209,7 +225,6 @@ class MoERunnerBase(MoERunner):
# called, i.e. by a MK or by the MoERunner. # called, i.e. by a MK or by the MoERunner.
# Once the MK can be created upfront, we can just pass in the proper # Once the MK can be created upfront, we can just pass in the proper
# flags derived from the quant_method's MK. # flags derived from the quant_method's MK.
reduce_results=reduce_results,
quant_method=quant_method, quant_method=quant_method,
enable_dbo=enable_dbo, enable_dbo=enable_dbo,
) )
...@@ -217,7 +232,7 @@ class MoERunnerBase(MoERunner): ...@@ -217,7 +232,7 @@ class MoERunnerBase(MoERunner):
# Needed for string -> FusedMoE layer lookup in custom ops. # Needed for string -> FusedMoE layer lookup in custom ops.
self.layer_name = layer_name self.layer_name = layer_name
self.forward_entry = self._select_forward() self._forward_entry = self._select_forward()
def _select_forward(self) -> Callable: def _select_forward(self) -> Callable:
if current_platform.is_tpu() or current_platform.is_cpu(): if current_platform.is_tpu() or current_platform.is_cpu():
...@@ -245,38 +260,6 @@ class MoERunnerBase(MoERunner): ...@@ -245,38 +260,6 @@ class MoERunnerBase(MoERunner):
def is_internal_router(self) -> bool: def is_internal_router(self) -> bool:
return self.gate is not None return self.gate is not None
@property
@abstractmethod
def reduce_results(self) -> bool:
raise NotImplementedError
def must_reduce_shared_expert_outputs(self) -> bool:
"""
The shared_experts are typically computed using the RowParallelLinear
layer. The result of this function is typically used as
the reduce_results argument to the module.
When just tensor-parallel is used, it is not required to reduce
the shared_experts results immediately. Instead we reduce at the
once at the end of the MoE op. (Refer to DeepSeekV2MoE module)
With EP and all2all kernels - this is no longer viable as all
GPU ranks in DP, produce the complete set of hidden_states.
Therefore it is required that we reduce the shared_experts output
early.
"""
return (
self.quant_method.moe_kernel is not None
and self.quant_method.moe_kernel.output_is_reduced()
)
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
"""
Some combine kernels reduce across GPU ranks by default.
"""
if self.must_reduce_shared_expert_outputs():
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
def apply_routed_input_transform( def apply_routed_input_transform(
self, hidden_states: torch.Tensor self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
...@@ -286,10 +269,6 @@ class MoERunnerBase(MoERunner): ...@@ -286,10 +269,6 @@ class MoERunnerBase(MoERunner):
is saved separately so shared experts get [S, hidden_size] while is saved separately so shared experts get [S, hidden_size] while
routed experts get the transformed [S, moe_latent_size]. routed experts get the transformed [S, moe_latent_size].
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
moved inside SharedFusedMoE to all-reduce on the smaller latent
dimension.
Returns (possibly transformed) hidden states and the input for shared Returns (possibly transformed) hidden states and the input for shared
experts (or None if there are no shared experts). experts (or None if there are no shared experts).
""" """
...@@ -306,33 +285,79 @@ class MoERunnerBase(MoERunner): ...@@ -306,33 +285,79 @@ class MoERunnerBase(MoERunner):
hidden_states if self._shared_experts is not None else None, hidden_states if self._shared_experts is not None else None,
) )
def apply_routed_output_transform(
    self,
    fused_output: torch.Tensor,
) -> torch.Tensor:
    """Apply transform to routed expert output (e.g., latent to full dim).

    Used by latent MoE models (e.g., NemotronH) where routed experts
    operate in a compressed latent space and need projection back to
    the full hidden dimension before combining with shared expert output.

    When ``self.routed_output_transform`` is None the input is returned
    unchanged.
    """
    if self.routed_output_transform is not None:
        transformed = self.routed_output_transform(fused_output)
        # The transform may return a tuple; only its first element is the
        # projected tensor, the rest is discarded.
        fused_output = transformed[0] if isinstance(transformed, tuple) else transformed
    return fused_output
def reduce_and_trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor: def _maybe_apply_routed_scale_to_output(
return trunc(self.maybe_all_reduce_tensor_model_parallel(x), trunc_size) self,
shared_output: torch.Tensor | None,
fused_output: torch.Tensor,
) -> tuple[torch.Tensor | None, torch.Tensor]:
"""Apply routed_scaling_factor to the output with FP16 overflow
protection.
Scale the fused expert output by routed_scaling_factor. For FP16,
avoid overflow by dividing shared_output by the scale instead
(the decoder layer compensates with matching divisions).
"""
if self.routed_scaling_factor != 1.0:
if fused_output.dtype != torch.float16:
fused_output *= self.routed_scaling_factor
elif shared_output is not None:
shared_output *= 1.0 / self.routed_scaling_factor
return shared_output, fused_output
def _maybe_reduce_shared_expert_output(
self,
shared_output: torch.Tensor | None,
) -> torch.Tensor | None:
"""All-reduce shared expert output when the combine kernel already
reduced fused output.
This is the "early" all-reduce path. When the combine kernel produces
already-reduced fused output, shared output must be reduced separately
to match.
"""
if self._fused_output_is_reduced:
assert shared_output is not None
shared_output = tensor_model_parallel_all_reduce(shared_output)
return shared_output
def _maybe_reduce_final_output(
self,
states: torch.Tensor,
trunc_size: int,
) -> torch.Tensor:
"""Truncate padded dimensions and all-reduce the combined output.
This is the "late" all-reduce path. When neither fused nor shared
output was individually reduced, the combined sum is all-reduced
here. Skipped when sequence-parallel is active (SP handles its
own reduction) or when the early path already reduced both outputs.
"""
# We don't need to reduce the final output if:
# - We are not running with TP or DP
# - The MK already reduced the fused output itself.
if ( if (
not self.moe_config.is_sequence_parallel not self.moe_config.is_sequence_parallel
and self.reduce_results
and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1) and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1)
and not self._fused_output_is_reduced
): ):
func = reduce_and_trunc states = tensor_model_parallel_all_reduce(states)
else:
func = trunc
if isinstance(states, tuple): return states[..., :trunc_size]
return tuple(
[func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
)
else:
assert len(trunc_sizes) == 1
return func(states, trunc_sizes[0])
def _encode_layer_name(self) -> str | LayerName: def _encode_layer_name(self) -> str | LayerName:
if _USE_LAYERNAME: if _USE_LAYERNAME:
...@@ -349,7 +374,15 @@ class MoERunnerBase(MoERunner): ...@@ -349,7 +374,15 @@ class MoERunnerBase(MoERunner):
self, self,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, list[int]]: ) -> tuple[torch.Tensor, int]:
"""Pad hidden_states to moe_config.hidden_dim and compute the
original dimension for later truncation.
For latent MoE, the routed hidden_states may be smaller than
hidden_dim. Padding ensures uniform tensor sizes through the
fused MoE kernel. The returned trunc_size is used by
_maybe_reduce_final_output to strip the padding from the result.
"""
shared_experts_hidden_dim = ( shared_experts_hidden_dim = (
shared_experts_input.shape[-1] if shared_experts_input is not None else 0 shared_experts_input.shape[-1] if shared_experts_input is not None else 0
) )
...@@ -365,10 +398,10 @@ class MoERunnerBase(MoERunner): ...@@ -365,10 +398,10 @@ class MoERunnerBase(MoERunner):
value=0.0, value=0.0,
) )
if self._shared_experts is not None: if self.routed_output_transform is not None and shared_experts_hidden_dim > 0:
orig_hidden_dims = [shared_experts_hidden_dim, transformed_hidden_dim] orig_hidden_dims = shared_experts_hidden_dim
else: else:
orig_hidden_dims = [transformed_hidden_dim] orig_hidden_dims = transformed_hidden_dim
return hidden_states, orig_hidden_dims return hidden_states, orig_hidden_dims
...@@ -388,6 +421,12 @@ class MoERunnerBase(MoERunner): ...@@ -388,6 +421,12 @@ class MoERunnerBase(MoERunner):
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> tuple[torch.Tensor | None, torch.Tensor]: ) -> tuple[torch.Tensor | None, torch.Tensor]:
"""Run expert routing and the fused MoE kernel via the quant method.
Orchestrates shared expert execution (before/after), expert selection
via the router, and the actual fused MoE computation. Returns
(shared_expert_output, fused_expert_output).
"""
# Run this before quant_method to avoid inplace issues. # Run this before quant_method to avoid inplace issues.
# TODO(bnell): probably not needed anymore since inplace is # TODO(bnell): probably not needed anymore since inplace is
# disabled when shared experts are present. # disabled when shared experts are present.
...@@ -428,6 +467,13 @@ class MoERunnerBase(MoERunner): ...@@ -428,6 +467,13 @@ class MoERunnerBase(MoERunner):
) )
def _sequence_parallel_context(self): def _sequence_parallel_context(self):
"""Return a context manager for sequence-parallel token
redistribution.
When sequence parallelism is active, returns a context that handles
local size tracking for proper token scatter/gather. Otherwise
returns a no-op context.
"""
ctx = get_forward_context() ctx = get_forward_context()
return ( return (
ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size) ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
...@@ -448,22 +494,25 @@ class MoERunnerBase(MoERunner): ...@@ -448,22 +494,25 @@ class MoERunnerBase(MoERunner):
def _maybe_add_zero_expert_output(
    self,
    result: torch.Tensor,
) -> torch.Tensor:
    """Add the zero expert's contribution to the final result.

    When a ZeroExpertRouter is used, it computes a bias-like output
    from the "zero expert" that is added to the combined routed+shared
    expert output. For any other router the input is returned unchanged.
    """
    if isinstance(self.router, ZeroExpertRouter):
        zero_expert_output = self.router.zero_expert_output
        assert zero_expert_output is not None
        # Out-of-place add: the caller's tensor is not modified.
        result = result + zero_expert_output
    return result
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
"""Invoke the fused moe layer. """Invoke the fused moe layer.
Input: Input:
...@@ -472,55 +521,88 @@ class MoERunnerBase(MoERunner): ...@@ -472,55 +521,88 @@ class MoERunnerBase(MoERunner):
Output: Output:
- The new hidden_states. - The new hidden_states.
or
- A tuple of (shared experts output, new hidden_states).
Calling sequence Calling sequence
- forward - forward
- self.forward_entry (_moe_forward or _moe_forward_shared custom op) - self._forward_entry (_moe_forward or _moe_forward_shared custom op)
- forward_dispatch - _forward_dispatch
- _forward_impl - _forward_impl
Note: The existence of _moe_forward and _moe_forward_shared custom ops are due Note: The existence of _moe_forward and _moe_forward_shared custom ops are due
to the following reasons: to the following reasons:
1. the chunking loop in ChunkingMoERunner._forward_impl cannot be compiled by 1. the chunking loop in ChunkingMoERunner._forward_impl cannot be compiled by
torch.compile torch.compile
2. pytorch cannot handle union types in custom op signatures so _moe_forward 2. pytorch cannot handle union types in custom op signatures so
and _moe_forward_shared must be split. _moe_forward and _moe_forward_shared must be split.
If ChunkingMoERunner._forward_impl can be implemented via torch.scan we can If ChunkingMoERunner._forward_impl can be implemented via torch.scan we can
potentially get rid of _moe_forward and _moe_forward_shared and collapse the potentially get rid of _moe_forward and _moe_forward_shared and collapse the
whole sequence into the 'forward' method. whole sequence into the 'forward' method.
""" """
# Apply transform for routed experts (e.g., latent projection for latent MoE) # Apply transform for routed experts (e.g., latent projection
# for latent MoE)
hidden_states, shared_experts_input = self.apply_routed_input_transform( hidden_states, shared_experts_input = self.apply_routed_input_transform(
hidden_states hidden_states
) )
hidden_states, og_hidden_dims = self._maybe_pad_hidden_states( hidden_states, og_hidden_dim = self._maybe_pad_hidden_states(
shared_experts_input, shared_experts_input,
hidden_states, hidden_states,
) )
fused_output = self.forward_entry( result = self._forward_entry(
hidden_states, hidden_states,
router_logits, router_logits,
shared_experts_input, shared_experts_input,
self._encode_layer_name(), self._encode_layer_name(),
) )
result = self._maybe_reduce_output(fused_output, og_hidden_dims) #
# Note: there are two all-reduce points below. They are mutually
# exclusive, controlled by _fused_output_is_reduced
# - When True: the combine kernel already reduced fused_output,
# so we reduce shared_output here to match, then skip the
# all-reduce in _maybe_reduce_final_output.
# - When False: neither output is reduced yet, so we combine
# them first and all-reduce the sum in _maybe_reduce_final_output.
# Extract outputs from result
shared_output, fused_output = _unpack(result)
# If combine kernel already reduced fused, reduce shared to match.
# See note above re: the two all-reduce points.
shared_output = self._maybe_reduce_shared_expert_output(shared_output)
shared_output, fused_output = self._maybe_apply_routed_scale_to_output(
shared_output, fused_output
)
# Apply output transform (e.g. latent -> full dim)
fused_output = self.apply_routed_output_transform(fused_output)
if shared_output is not None:
result = shared_output + fused_output
else:
result = fused_output
result = self._maybe_reduce_final_output(result, og_hidden_dim)
return self._maybe_add_zero_expert_output(result) return self._maybe_add_zero_expert_output(result)
def forward_dispatch( def _forward_dispatch(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Entry point called by the custom op to run the MoE computation.
Handles pre-dispatch setup (gate application, external shared expert
triggering, quant config init) then delegates to _forward_impl within
the sequence-parallel context.
"""
# TODO(bnell): this can be removed after MK migration is complete. # TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init() layer.ensure_moe_quant_config_init()
...@@ -549,4 +631,11 @@ class MoERunnerBase(MoERunner): ...@@ -549,4 +631,11 @@ class MoERunnerBase(MoERunner):
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Core MoE computation to be implemented by subclasses.
Performs expert routing, fused MoE kernel execution, and shared
expert computation. Returns a single tensor (fused output only)
or a tuple of (shared_output, fused_output) when shared experts
are present.
"""
raise NotImplementedError raise NotImplementedError
...@@ -29,8 +29,9 @@ def create_moe_runner( ...@@ -29,8 +29,9 @@ def create_moe_runner(
gate: torch.nn.Module | None, gate: torch.nn.Module | None,
shared_experts: SharedExperts | None, shared_experts: SharedExperts | None,
quant_method: FusedMoEMethodBase, quant_method: FusedMoEMethodBase,
reduce_results: bool,
enable_dbo: bool, enable_dbo: bool,
routed_output_transform: torch.nn.Module | None = None,
routed_scaling_factor: float = 1.0,
) -> MoERunner: ) -> MoERunner:
return DefaultMoERunner( return DefaultMoERunner(
layer_name, layer_name,
...@@ -40,6 +41,7 @@ def create_moe_runner( ...@@ -40,6 +41,7 @@ def create_moe_runner(
gate, gate,
shared_experts, shared_experts,
quant_method, quant_method,
reduce_results,
enable_dbo, enable_dbo,
routed_output_transform=routed_output_transform,
routed_scaling_factor=routed_scaling_factor,
) )
...@@ -5,10 +5,6 @@ from enum import IntEnum ...@@ -5,10 +5,6 @@ from enum import IntEnum
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
...@@ -48,7 +44,6 @@ class SharedExperts: ...@@ -48,7 +44,6 @@ class SharedExperts:
layer: torch.nn.Module, layer: torch.nn.Module,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_method: QuantizeMethodBase, quant_method: QuantizeMethodBase,
reduce_results: bool,
enable_dbo: bool, enable_dbo: bool,
): ):
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
...@@ -68,7 +63,6 @@ class SharedExperts: ...@@ -68,7 +63,6 @@ class SharedExperts:
self._layer = layer self._layer = layer
self._moe_config = moe_config self._moe_config = moe_config
self._quant_method = quant_method self._quant_method = quant_method
self._reduce_results = reduce_results
# Allow disabling of the separate shared experts stream for # Allow disabling of the separate shared experts stream for
# debug purposes. # debug purposes.
...@@ -139,18 +133,6 @@ class SharedExperts: ...@@ -139,18 +133,6 @@ class SharedExperts:
return output return output
def _maybe_reduce_shared_out(self, shared_out: torch.Tensor) -> torch.Tensor:
# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
if (
self._reduce_results
and self._quant_method.moe_kernel is not None
and self._quant_method.moe_kernel.output_is_reduced()
and get_tensor_model_parallel_world_size() > 1
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out
@property @property
def _output_idx(self) -> int: def _output_idx(self) -> int:
return dbo_current_ubatch_id() if self.enable_dbo else 0 return dbo_current_ubatch_id() if self.enable_dbo else 0
......
...@@ -18,12 +18,8 @@ class SharedFusedMoE(FusedMoE): ...@@ -18,12 +18,8 @@ class SharedFusedMoE(FusedMoE):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
result = super().forward( return super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
) )
if self.shared_experts is None:
return None, result
else:
return result
...@@ -100,7 +100,7 @@ class AXK1MoE(nn.Module): ...@@ -100,7 +100,7 @@ class AXK1MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = get_ep_group().rank_in_group self.ep_rank = get_ep_group().rank_in_group
...@@ -170,7 +170,6 @@ class AXK1MoE(nn.Module): ...@@ -170,7 +170,6 @@ class AXK1MoE(nn.Module):
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob, renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
use_grouped_topk=True, use_grouped_topk=True,
...@@ -180,9 +179,8 @@ class AXK1MoE(nn.Module): ...@@ -180,9 +179,8 @@ class AXK1MoE(nn.Module):
scoring_func=config.scoring_func, scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul # we do scaling outside, set factor to 1.0 to avoid double mul
# aiter applies routed_scaling_factor internally # aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0 routed_scaling_factor=self.routed_scaling_factor,
if not self.is_rocm_aiter_moe_enabled apply_routed_scale_to_output=not self.is_rocm_aiter_moe_enabled,
else self.routed_scaling_factor,
e_score_correction_bias=self.gate.e_score_correction_bias, e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
...@@ -204,43 +202,20 @@ class AXK1MoE(nn.Module): ...@@ -204,43 +202,20 @@ class AXK1MoE(nn.Module):
hidden_states = sequence_parallel_chunk(hidden_states) hidden_states = sequence_parallel_chunk(hidden_states)
if self.experts.is_internal_router: if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class final_hidden_states = self.experts(
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states hidden_states=hidden_states, router_logits=hidden_states
) )
else: else:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None:
assert shared_output is None
# Fix FP16 overflow
# See AXK1DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
if not self.is_rocm_aiter_moe_enabled:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel: if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0 final_hidden_states, 0
) )
final_hidden_states = final_hidden_states[:num_tokens] final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
......
...@@ -131,7 +131,6 @@ class AfmoeMoE(nn.Module): ...@@ -131,7 +131,6 @@ class AfmoeMoE(nn.Module):
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.route_norm if self.score_func == "sigmoid" else False, renormalize=self.route_norm if self.score_func == "sigmoid" else False,
quant_config=quant_config, quant_config=quant_config,
use_grouped_topk=True, use_grouped_topk=True,
...@@ -152,20 +151,10 @@ class AfmoeMoE(nn.Module): ...@@ -152,20 +151,10 @@ class AfmoeMoE(nn.Module):
router_logits = self.gate(hidden_states.to(dtype=torch.float32)) router_logits = self.gate(hidden_states.to(dtype=torch.float32))
fused_moe_out = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
if self.shared_experts is not None:
shared_output, final_hidden_states = fused_moe_out
final_hidden_states = final_hidden_states + shared_output
else:
final_hidden_states = fused_moe_out
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
......
...@@ -283,7 +283,6 @@ class AriaTextMoELayer(nn.Module): ...@@ -283,7 +283,6 @@ class AriaTextMoELayer(nn.Module):
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
reduce_results=True,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
) )
...@@ -301,12 +300,7 @@ class AriaTextMoELayer(nn.Module): ...@@ -301,12 +300,7 @@ class AriaTextMoELayer(nn.Module):
router_output = torch.nn.functional.linear(hidden_states, self.router_weight) router_output = torch.nn.functional.linear(hidden_states, self.router_weight)
sparse_expert_output = self.experts(hidden_states, router_output) return self.experts(hidden_states, router_output)
if self.shared_experts is not None:
return sparse_expert_output[0] + sparse_expert_output[1]
else:
return sparse_expert_output
class AriaTextDecoderLayer(LlamaDecoderLayer): class AriaTextDecoderLayer(LlamaDecoderLayer):
......
...@@ -291,7 +291,6 @@ class BailingMoE(nn.Module): ...@@ -291,7 +291,6 @@ class BailingMoE(nn.Module):
top_k=self.top_k, top_k=self.top_k,
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob, renormalize=self.norm_expert_prob,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
...@@ -301,6 +300,7 @@ class BailingMoE(nn.Module): ...@@ -301,6 +300,7 @@ class BailingMoE(nn.Module):
topk_group=self.topk_group, topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
router_logits_dtype=self.router_dtype, router_logits_dtype=self.router_dtype,
routed_scaling_factor=self.routed_scaling_factor,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -314,21 +314,6 @@ class BailingMoE(nn.Module): ...@@ -314,21 +314,6 @@ class BailingMoE(nn.Module):
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
if self.shared_experts is not None:
shared_output, final_hidden_states = final_hidden_states
else:
shared_output = None
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_size) return final_hidden_states.view(num_tokens, hidden_size)
......
...@@ -358,7 +358,6 @@ class BailingMoeV25(nn.Module): ...@@ -358,7 +358,6 @@ class BailingMoeV25(nn.Module):
top_k=self.top_k, top_k=self.top_k,
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob, renormalize=self.norm_expert_prob,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
...@@ -368,6 +367,8 @@ class BailingMoeV25(nn.Module): ...@@ -368,6 +367,8 @@ class BailingMoeV25(nn.Module):
topk_group=self.topk_group, topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
router_logits_dtype=self.router_dtype, router_logits_dtype=self.router_dtype,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scale_to_output=True,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -383,22 +384,6 @@ class BailingMoeV25(nn.Module): ...@@ -383,22 +384,6 @@ class BailingMoeV25(nn.Module):
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
# Handle tuple return from SharedFusedMoE
if self.shared_experts is not None:
shared_output, final_hidden_states = final_hidden_states
else:
shared_output = None
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_size) return final_hidden_states.view(num_tokens, hidden_size)
......
...@@ -85,7 +85,6 @@ class DbrxExperts(FusedMoE): ...@@ -85,7 +85,6 @@ class DbrxExperts(FusedMoE):
hidden_size=config.d_model, hidden_size=config.d_model,
intermediate_size=config.ffn_config.ffn_hidden_size, intermediate_size=config.ffn_config.ffn_hidden_size,
params_dtype=params_dtype, params_dtype=params_dtype,
reduce_results=True,
renormalize=True, renormalize=True,
quant_config=quant_config, quant_config=quant_config,
tp_size=get_tensor_model_parallel_world_size(), tp_size=get_tensor_model_parallel_world_size(),
......
...@@ -318,7 +318,6 @@ class DeepseekV2MoE(nn.Module): ...@@ -318,7 +318,6 @@ class DeepseekV2MoE(nn.Module):
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob, renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
use_grouped_topk=True, use_grouped_topk=True,
...@@ -326,11 +325,9 @@ class DeepseekV2MoE(nn.Module): ...@@ -326,11 +325,9 @@ class DeepseekV2MoE(nn.Module):
topk_group=getattr(config, "topk_group", 1), topk_group=getattr(config, "topk_group", 1),
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
scoring_func=getattr(config, "scoring_func", "softmax"), scoring_func=getattr(config, "scoring_func", "softmax"),
# we do scaling outside, set factor to 1.0 to avoid double mul
# aiter applies routed_scaling_factor internally # aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0 routed_scaling_factor=self.routed_scaling_factor,
if not self.is_rocm_aiter_moe_enabled apply_routed_scale_to_output=not self.is_rocm_aiter_moe_enabled,
else self.routed_scaling_factor,
e_score_correction_bias=self.gate.e_score_correction_bias, e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
...@@ -363,43 +360,20 @@ class DeepseekV2MoE(nn.Module): ...@@ -363,43 +360,20 @@ class DeepseekV2MoE(nn.Module):
hidden_states = sequence_parallel_chunk(hidden_states) hidden_states = sequence_parallel_chunk(hidden_states)
if self.experts.is_internal_router: if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class final_hidden_states = self.experts(
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states hidden_states=hidden_states, router_logits=hidden_states
) )
else: else:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None:
assert shared_output is None
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
if not self.is_rocm_aiter_moe_enabled:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel: if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0 final_hidden_states, 0
) )
final_hidden_states = final_hidden_states[:num_tokens] final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
......
...@@ -37,7 +37,6 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig ...@@ -37,7 +37,6 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import ( from vllm.distributed import (
get_pp_group, get_pp_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
...@@ -120,7 +119,6 @@ class Dots1MoE(nn.Module): ...@@ -120,7 +119,6 @@ class Dots1MoE(nn.Module):
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts self.n_shared_experts = config.n_shared_experts
...@@ -163,7 +161,6 @@ class Dots1MoE(nn.Module): ...@@ -163,7 +161,6 @@ class Dots1MoE(nn.Module):
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob, renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
use_grouped_topk=True, use_grouped_topk=True,
...@@ -171,9 +168,9 @@ class Dots1MoE(nn.Module): ...@@ -171,9 +168,9 @@ class Dots1MoE(nn.Module):
topk_group=config.topk_group, topk_group=config.topk_group,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
scoring_func=config.scoring_func, scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias, e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scale_to_output=True,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -182,16 +179,9 @@ class Dots1MoE(nn.Module): ...@@ -182,16 +179,9 @@ class Dots1MoE(nn.Module):
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
shared_out, routed_out = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
if self.shared_experts is not None:
final_hidden_states = (routed_out + shared_out) * self.routed_scaling_factor
else:
final_hidden_states = routed_out * self.routed_scaling_factor
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment