Unverified Commit e50c4546 authored by Ilya Markov's avatar Ilya Markov Committed by GitHub
Browse files

[BugFix] Support EP/DP + EPLB with MTP (#25311)


Signed-off-by: default avatarilmarkov <markovilya197@gmail.com>
Signed-off-by: default avatarSage Moore <sage@neuralmagic.com>
Co-authored-by: default avatarSage Moore <sage@neuralmagic.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
parent 5d16d0fa
...@@ -232,8 +232,8 @@ steps: ...@@ -232,8 +232,8 @@ steps:
commands: commands:
- pytest -v -s distributed/test_eplb_algo.py - pytest -v -s distributed/test_eplb_algo.py
- label: EPLB Execution Test # 5min - label: EPLB Execution Test # 10min
timeout_in_minutes: 15 timeout_in_minutes: 20
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 4 num_gpus: 4
source_file_dependencies: source_file_dependencies:
...@@ -241,6 +241,7 @@ steps: ...@@ -241,6 +241,7 @@ steps:
- tests/distributed/test_eplb_execute.py - tests/distributed/test_eplb_execute.py
commands: commands:
- pytest -v -s distributed/test_eplb_execute.py - pytest -v -s distributed/test_eplb_execute.py
- pytest -v -s distributed/test_eplb_spec_decode.py
- label: Metrics, Tracing Test # 12min - label: Metrics, Tracing Test # 12min
timeout_in_minutes: 20 timeout_in_minutes: 20
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import lm_eval
import pytest
from tests.utils import large_gpu_mark
def get_model_args(
model_name: str,
spec_model_name: str,
spec_method: str,
tp_size: int,
model_max_len: int,
) -> dict:
speculative_config = {
"method": spec_method,
"model": spec_model_name,
"num_speculative_tokens": 1,
"max_model_len": model_max_len,
}
model_args = {
"pretrained": model_name,
"dtype": "auto",
"add_bos_token": True,
"tensor_parallel_size": tp_size,
"gpu_memory_utilization": 0.7,
"speculative_config": speculative_config,
"enable_expert_parallel": True,
"num_redundant_experts": tp_size,
"eplb_window_size": 128,
"eplb_step_interval": 1024,
"eplb_log_balancedness": False,
"enable_eplb": True,
"max_model_len": model_max_len,
}
return model_args
@pytest.mark.parametrize(
"model_setup",
[
pytest.param(
("mtp", "Qwen/Qwen3-Next-80B-A3B-Instruct", None, 4, 0.86),
marks=large_gpu_mark(min_gb=80),
),
pytest.param(
(
"eagle",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
4,
0.92,
),
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues"),
),
],
ids=["qwen3_next_mtp", "llama4_eagle"],
)
def test_eplb_spec_decode(
monkeypatch: pytest.MonkeyPatch,
model_setup: tuple[str, str, str, int, float],
):
"""
Test the correctness of EPLB speculative decoding with GSM8K dataset.
Applicable to MoE models with mtp or eagle spec decode.
"""
method, model_name, spec_model_name, tp_size, expected_gsm8k_value = model_setup
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
model_args = get_model_args(
model_name=model_name,
spec_model_name=spec_model_name,
spec_method=method,
tp_size=tp_size,
model_max_len=4096,
)
results = lm_eval.simple_evaluate(
model="vllm",
model_args=model_args,
tasks=TASK,
batch_size=64,
num_fewshot=8,
)
measured_value = results["results"][TASK][FILTER]
assert (
measured_value - RTOL < expected_gsm8k_value
and measured_value + RTOL > expected_gsm8k_value
), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
...@@ -33,7 +33,7 @@ from dataclasses import dataclass ...@@ -33,7 +33,7 @@ from dataclasses import dataclass
import torch import torch
from torch.distributed import ProcessGroup, all_reduce from torch.distributed import ProcessGroup, all_reduce
from vllm.config import ParallelConfig from vllm.config import ModelConfig, ParallelConfig
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_ep_group, get_ep_group,
get_node_count, get_node_count,
...@@ -50,7 +50,7 @@ logger = init_logger(__name__) ...@@ -50,7 +50,7 @@ logger = init_logger(__name__)
@dataclass @dataclass
class EplbState: class EplbModelState:
"""EPLB metrics.""" """EPLB metrics."""
physical_to_logical_map: torch.Tensor physical_to_logical_map: torch.Tensor
...@@ -130,20 +130,31 @@ class EplbState: ...@@ -130,20 +130,31 @@ class EplbState:
See: See:
https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856 https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
""" """
expert_load_window_step: int = 0 model_name: str
model: MixtureOfExperts
class EplbState:
"""
EplbState of each expert parallel model. Key is the model config hash.
"""
def __init__(self, parallel_config: ParallelConfig, device: torch.device):
self.parallel_config = parallel_config
self.device = device
self.model_states: dict[str, EplbModelState] = {}
""" """
Current step in the sliding window. Current step in the sliding window.
Different from `expert_rearrangement_step`, each EP rank may have its own Different from `expert_rearrangement_step`,
`expert_load_window_step`. each EP rank may have its own `expert_load_window_step`.
""" """
expert_load_window_size: int = 0 self.expert_load_window_step: int = 0
""" """
Size of the expert load sliding window. Size of the expert load sliding window.
This is a constant and is taken from the config. This is a constant and is taken from the config.
""" """
self.expert_load_window_size: int = 0
expert_rearrangement_step: int = 0
""" """
Steps after last rearrangement. Steps after last rearrangement.
Will trigger a rearrangement if it exceeds the threshold. Will trigger a rearrangement if it exceeds the threshold.
...@@ -153,11 +164,12 @@ class EplbState: ...@@ -153,11 +164,12 @@ class EplbState:
Otherwise, the rearrangement will hang at collective Otherwise, the rearrangement will hang at collective
communication calls. communication calls.
""" """
expert_rearrangement_step_interval: int = 0 self.expert_rearrangement_step: int = 0
""" """
Interval for expert rearrangement steps. Interval for expert rearrangement steps.
This is a constant and is taken from the config. This is a constant and is taken from the config.
""" """
self.expert_rearrangement_step_interval: int = 0
@staticmethod @staticmethod
def build_initial_global_physical_to_logical_map( def build_initial_global_physical_to_logical_map(
...@@ -179,26 +191,63 @@ class EplbState: ...@@ -179,26 +191,63 @@ class EplbState:
] ]
return global_physical_to_logical_map return global_physical_to_logical_map
@classmethod def validate_ep_configuration(self, new_model: MixtureOfExperts):
def build( """
cls, Validate that the expert parallel configuration of
the new model is the same as the existing models.
"""
if len(self.model_states) > 0:
model = next(iter(self.model_states.values())).model
if (
model.num_routed_experts != new_model.num_routed_experts
or model.num_redundant_experts != new_model.num_redundant_experts
or model.num_physical_experts != new_model.num_physical_experts
or model.num_logical_experts != new_model.num_logical_experts
or model.num_expert_groups != new_model.num_expert_groups
):
raise RuntimeError(
"Model: {} "
"with config {} "
"{} {} {} {} "
"mismatch with new model {} "
"with config {} "
"{} {} {} {}".format(
type(model),
model.num_routed_experts,
model.num_redundant_experts,
model.num_physical_experts,
model.num_logical_experts,
model.num_expert_groups,
type(new_model),
new_model.num_routed_experts,
new_model.num_redundant_experts,
new_model.num_physical_experts,
new_model.num_logical_experts,
new_model.num_expert_groups,
)
)
def add_model(
self,
model: MixtureOfExperts, model: MixtureOfExperts,
device: torch.device, model_config: ModelConfig,
parallel_config: ParallelConfig,
global_expert_load: torch.Tensor | None = None, global_expert_load: torch.Tensor | None = None,
old_global_expert_indices: torch.Tensor | None = None, old_global_expert_indices: torch.Tensor | None = None,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
) -> "EplbState": ):
""" """
Build the initial EPLB state. Build the initial EPLB state.
""" """
physical_to_logical_map_list = cls.build_initial_global_physical_to_logical_map( self.validate_ep_configuration(model)
physical_to_logical_map_list = (
EplbState.build_initial_global_physical_to_logical_map(
model.num_routed_experts, model.num_routed_experts,
model.num_redundant_experts, model.num_redundant_experts,
) )
)
physical_to_logical_map = torch.tensor( physical_to_logical_map = torch.tensor(
physical_to_logical_map_list, physical_to_logical_map_list,
device=device, device=self.device,
) )
# Assuming 8 GPUs per node, this supports up to # Assuming 8 GPUs per node, this supports up to
# (1023 + 1) / 8 = 128 nodes for now. # (1023 + 1) / 8 = 128 nodes for now.
...@@ -212,11 +261,11 @@ class EplbState: ...@@ -212,11 +261,11 @@ class EplbState:
logical_to_physical_map = torch.full( logical_to_physical_map = torch.full(
(model.num_logical_experts, max_slots_per_logical_expert), (model.num_logical_experts, max_slots_per_logical_expert),
-1, -1,
device=device, device=self.device,
) )
logical_replica_count = torch.zeros( logical_replica_count = torch.zeros(
(model.num_logical_experts,), (model.num_logical_experts,),
device=device, device=self.device,
dtype=torch.long, dtype=torch.long,
) )
...@@ -255,18 +304,25 @@ class EplbState: ...@@ -255,18 +304,25 @@ class EplbState:
expert_load_pass = torch.zeros( expert_load_pass = torch.zeros(
(model.num_moe_layers, model.num_physical_experts), (model.num_moe_layers, model.num_physical_experts),
dtype=torch.int32, dtype=torch.int32,
device=device, device=self.device,
) )
expert_load_window_size = parallel_config.eplb_config.window_size self.expert_load_window_size = self.parallel_config.eplb_config.window_size
expert_load_window = torch.zeros( expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers, model.num_physical_experts), (
self.expert_load_window_size,
model.num_moe_layers,
model.num_physical_experts,
),
dtype=torch.int32, dtype=torch.int32,
device=device, device=self.device,
) )
# Set the initial progress of rearrangement to 3/4 # Set the initial progress of rearrangement to 3/4
eplb_step_interval = parallel_config.eplb_config.step_interval eplb_step_interval = self.parallel_config.eplb_config.step_interval
expert_rearrangement_step = max(0, eplb_step_interval - eplb_step_interval // 4) self.expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4
)
self.expert_rearrangement_step_interval = eplb_step_interval
if global_expert_load is not None: if global_expert_load is not None:
ep_group = get_ep_group().device_group ep_group = get_ep_group().device_group
...@@ -309,7 +365,7 @@ class EplbState: ...@@ -309,7 +365,7 @@ class EplbState:
(0, logical_to_physical_map.shape[-1] - max_physical_slots), (0, logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1, value=-1,
) )
physical_to_logical_map = new_physical_to_logical_map.to(device) physical_to_logical_map = new_physical_to_logical_map.to(self.device)
logical_to_physical_map.copy_(new_logical_to_physical_map) logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count) logical_replica_count.copy_(new_logical_replica_count)
...@@ -327,22 +383,20 @@ class EplbState: ...@@ -327,22 +383,20 @@ class EplbState:
False, False,
rank_mapping, rank_mapping,
) )
expert_rearrangement_step = 0 self.expert_rearrangement_step = 0
return cls( self.model_states[model_config.compute_hash()] = EplbModelState(
physical_to_logical_map, physical_to_logical_map,
logical_to_physical_map, logical_to_physical_map,
logical_replica_count, logical_replica_count,
expert_load_pass, expert_load_pass,
expert_load_window, expert_load_window,
expert_load_window_size=expert_load_window_size, model_config.model,
expert_rearrangement_step=expert_rearrangement_step, model,
expert_rearrangement_step_interval=eplb_step_interval,
) )
def step( def step(
self, self,
model: MixtureOfExperts,
is_dummy: bool = False, is_dummy: bool = False,
is_profile: bool = False, is_profile: bool = False,
log_stats: bool = False, log_stats: bool = False,
...@@ -351,7 +405,6 @@ class EplbState: ...@@ -351,7 +405,6 @@ class EplbState:
Step the EPLB state. Step the EPLB state.
Args: Args:
model (MixtureOfExperts): The MoE model.
is_dummy (bool): If `True`, this is a dummy step and the load is_dummy (bool): If `True`, this is a dummy step and the load
metrics recorded in this forward pass will not count. metrics recorded in this forward pass will not count.
Defaults to `False`. Defaults to `False`.
...@@ -369,25 +422,26 @@ class EplbState: ...@@ -369,25 +422,26 @@ class EplbState:
""" """
if is_profile: if is_profile:
self.rearrange(model, is_profile=True) self.rearrange(is_profile=True)
return return
if is_dummy: if is_dummy:
# Do not record load metrics for dummy steps # Do not record load metrics for dummy steps
self.expert_load_pass.zero_() for eplb_model_state in self.model_states.values():
eplb_model_state.expert_load_pass.zero_()
if log_stats: if log_stats:
# total_expert_load_pass: (num_moe_layers, num_physical_experts) # Sync the expert load pass for each model (main and drafter).
total_expert_load_pass = self.expert_load_pass.clone() # expert_load_pass: (num_moe_layers, num_physical_experts)
expert_load_pass_list = self._sync_load_pass()
# Collect load metrics from all ranks
ep_group = get_ep_group().device_group ep_group = get_ep_group().device_group
all_reduce(total_expert_load_pass, group=ep_group) for expert_load_pass, eplb_model_state in zip(
expert_load_pass_list, self.model_states.values()
):
# num_tokens_per_rank: (num_moe_layers, num_ranks) # num_tokens_per_rank: (num_moe_layers, num_ranks)
num_tokens_per_rank = ( num_tokens_per_rank = (
total_expert_load_pass.reshape( expert_load_pass.reshape(
total_expert_load_pass.shape[0], ep_group.size(), -1 expert_load_pass.shape[0], ep_group.size(), -1
) )
.sum(dim=-1) .sum(dim=-1)
.float() .float()
...@@ -408,7 +462,10 @@ class EplbState: ...@@ -408,7 +462,10 @@ class EplbState:
if ep_group.rank() == 0: if ep_group.rank() == 0:
logger.info( logger.info(
"EPLB step: avg_tokens=%.2f, max_tokens=%d, balancedness=%.4f", "EPLB step: %d for model %s: avg_tokens=%.2f, "
"max_tokens=%d, balancedness=%.4f",
self.expert_rearrangement_step,
eplb_model_state.model_name,
avg_tokens, avg_tokens,
max_tokens, max_tokens,
balancedness, balancedness,
...@@ -416,13 +473,15 @@ class EplbState: ...@@ -416,13 +473,15 @@ class EplbState:
# Update the expert load sliding window # Update the expert load sliding window
if not is_dummy: if not is_dummy:
self.expert_load_window[self.expert_load_window_step] = ( for eplb_model_state in self.model_states.values():
self.expert_load_pass.clone() eplb_model_state.expert_load_window[self.expert_load_window_step] = (
eplb_model_state.expert_load_pass.clone()
) )
eplb_model_state.expert_load_pass.zero_()
self.expert_load_window_step += 1 self.expert_load_window_step += 1
if self.expert_load_window_step >= self.expert_load_window_size: if self.expert_load_window_step >= self.expert_load_window_size:
self.expert_load_window_step = 0 self.expert_load_window_step = 0
self.expert_load_pass.zero_()
# Step the expert rearrangement step # Step the expert rearrangement step
# Note that even if this is a dummy step, we still increment the # Note that even if this is a dummy step, we still increment the
...@@ -431,18 +490,30 @@ class EplbState: ...@@ -431,18 +490,30 @@ class EplbState:
self.expert_rearrangement_step += 1 self.expert_rearrangement_step += 1
if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval: if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
self.expert_rearrangement_step = 0 self.expert_rearrangement_step = 0
self.rearrange(model) self.rearrange()
def rearrange( def rearrange(
self, self,
model: MixtureOfExperts,
is_profile: bool = False, is_profile: bool = False,
execute_shuffle: bool = True, execute_shuffle: bool = True,
global_expert_load: torch.Tensor | None = None, global_expert_loads: list[torch.Tensor] | None = None,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
""" """
Rearrange the experts according to the current load. Rearrange the experts according to the current load.
Args:
is_profile (bool): If `True`, perform a dummy rearrangement.
This is used in `profile_run` to reserve enough memory,
no memory movement will be performed. Default is False.
execute_shuffle (bool): If `True`, execute the shuffle
in elastic expert parallel (EEP). Default is True.
global_expert_loads (list[torch.Tensor] | None): The global expert
loads when scaling is done in EEP.
List of expert loads for the main and drafter
(when spec decode is used) models.
rank_mapping (dict[int, int] | None): The rank mapping
when scaling is done in EEP.
""" """
ep_group = get_ep_group().device_group ep_group = get_ep_group().device_group
...@@ -455,29 +526,39 @@ class EplbState: ...@@ -455,29 +526,39 @@ class EplbState:
time_start = time.perf_counter() time_start = time.perf_counter()
logger.info("Rearranging experts %s...", "(profile)" if is_profile else "") logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
if global_expert_load is None: if global_expert_loads is None:
# Map the physical expert load to global logical experts # Map the physical expert load to global logical experts
global_expert_load_windows = []
if not execute_shuffle:
num_models = torch.tensor(
[len(self.model_states)], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
num_models, group=get_ep_group().cpu_group, group_src=0
)
for eplb_model_state in self.model_states.values():
logical_expert_load_window = torch.zeros( logical_expert_load_window = torch.zeros(
self.expert_load_window_size, self.expert_load_window_size,
model.num_moe_layers, eplb_model_state.model.num_moe_layers,
model.num_logical_experts, eplb_model_state.model.num_logical_experts,
dtype=self.expert_load_window.dtype, dtype=eplb_model_state.expert_load_window.dtype,
device=self.expert_load_window.device, device=eplb_model_state.expert_load_window.device,
) )
logical_expert_load_window.scatter_add_( logical_expert_load_window.scatter_add_(
dim=-1, dim=-1,
index=self.physical_to_logical_map.unsqueeze(0) index=eplb_model_state.physical_to_logical_map.unsqueeze(0)
.expand_as(self.expert_load_window) .expand_as(eplb_model_state.expert_load_window)
.long(), .long(),
src=self.expert_load_window, src=eplb_model_state.expert_load_window,
) )
if not execute_shuffle: if not execute_shuffle:
metadata = torch.tensor( metadata = torch.tensor(
[ [
model.num_moe_layers, eplb_model_state.model.num_moe_layers,
model.num_logical_experts, eplb_model_state.model.num_logical_experts,
self.physical_to_logical_map.shape[1], eplb_model_state.physical_to_logical_map.shape[1],
], ],
dtype=torch.int32, dtype=torch.int32,
device="cpu", device="cpu",
...@@ -486,22 +567,30 @@ class EplbState: ...@@ -486,22 +567,30 @@ class EplbState:
metadata, group=get_ep_group().cpu_group, group_src=0 metadata, group=get_ep_group().cpu_group, group_src=0
) )
# Perform all-reduce to get the expert load across all ranks
global_expert_load_window = logical_expert_load_window.sum(dim=0) global_expert_load_window = logical_expert_load_window.sum(dim=0)
all_reduce(global_expert_load_window, group=ep_group) global_expert_load_windows.append(global_expert_load_window)
# Perform all-reduce to get the expert load across all ranks for each model
global_expert_load_windows = self._allreduce_list(
global_expert_load_windows
)
if not execute_shuffle: if not execute_shuffle:
for eplb_model_state, global_expert_load_window in zip(
self.model_states.values(), global_expert_load_windows
):
# (num_moe_layers, old_num_physical_experts) # (num_moe_layers, old_num_physical_experts)
old_global_expert_indices = self.physical_to_logical_map old_global_expert_indices = eplb_model_state.physical_to_logical_map
torch.distributed.broadcast( torch.distributed.broadcast(
old_global_expert_indices, group=ep_group, group_src=0 old_global_expert_indices, group=ep_group, group_src=0
) )
return global_expert_load_window if not execute_shuffle:
return global_expert_load_windows
else: else:
assert execute_shuffle assert execute_shuffle
global_expert_load_window = global_expert_load global_expert_load_windows = global_expert_loads
# TODO(bowen): Treat differently for prefill and decode nodes # TODO(bowen): Treat differently for prefill and decode nodes
eplb_model_state = next(iter(self.model_states.values()))
model = eplb_model_state.model
num_replicas = model.num_physical_experts num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups num_groups = model.num_expert_groups
if rank_mapping is not None and len(rank_mapping) == ep_group.size(): if rank_mapping is not None and len(rank_mapping) == ep_group.size():
...@@ -526,7 +615,10 @@ class EplbState: ...@@ -526,7 +615,10 @@ class EplbState:
f"{num_gpus=}, {num_nodes=}" f"{num_gpus=}, {num_nodes=}"
) )
# Get new expert mappings for eplb_model_state, global_expert_load_window in zip(
self.model_states.values(), global_expert_load_windows
):
# Get new expert mappings for the model
( (
new_physical_to_logical_map, new_physical_to_logical_map,
new_logical_to_physical_map, new_logical_to_physical_map,
...@@ -541,9 +633,9 @@ class EplbState: ...@@ -541,9 +633,9 @@ class EplbState:
# Update expert weights # Update expert weights
rearrange_expert_weights_inplace( rearrange_expert_weights_inplace(
self.physical_to_logical_map, eplb_model_state.physical_to_logical_map,
new_physical_to_logical_map, new_physical_to_logical_map,
model.expert_weights, eplb_model_state.model.expert_weights,
ep_group, ep_group,
is_profile, is_profile,
rank_mapping, rank_mapping,
...@@ -551,23 +643,36 @@ class EplbState: ...@@ -551,23 +643,36 @@ class EplbState:
if not is_profile: if not is_profile:
if ( if (
self.physical_to_logical_map.shape[1] eplb_model_state.physical_to_logical_map.shape[1]
!= new_physical_to_logical_map.shape[1] != new_physical_to_logical_map.shape[1]
): ):
self.physical_to_logical_map = new_physical_to_logical_map.to( eplb_model_state.physical_to_logical_map = (
self.physical_to_logical_map.device new_physical_to_logical_map.to(
eplb_model_state.physical_to_logical_map.device
)
) )
else: else:
self.physical_to_logical_map.copy_(new_physical_to_logical_map) eplb_model_state.physical_to_logical_map.copy_(
new_physical_to_logical_map
)
max_physical_slots = new_logical_to_physical_map.shape[-1] max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= self.logical_to_physical_map.shape[-1] assert (
max_physical_slots
<= eplb_model_state.logical_to_physical_map.shape[-1]
)
new_logical_to_physical_map = torch.nn.functional.pad( new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map, new_logical_to_physical_map,
(0, self.logical_to_physical_map.shape[-1] - max_physical_slots), (
0,
eplb_model_state.logical_to_physical_map.shape[-1]
- max_physical_slots,
),
value=-1, value=-1,
) )
self.logical_to_physical_map.copy_(new_logical_to_physical_map) eplb_model_state.logical_to_physical_map.copy_(
self.logical_replica_count.copy_(new_logical_replica_count) new_logical_to_physical_map
)
eplb_model_state.logical_replica_count.copy_(new_logical_replica_count)
if is_main_rank: if is_main_rank:
assert time_start is not None assert time_start is not None
...@@ -581,11 +686,17 @@ class EplbState: ...@@ -581,11 +686,17 @@ class EplbState:
return None return None
@staticmethod @staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]: def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
""" """
Receive the expert load and old placement from the master rank. Receive the expert load and old placement from the master rank.
""" """
ep_group = get_ep_group() ep_group = get_ep_group()
num_models = torch.empty(1, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0)
num_models = num_models.item()
global_expert_loads = []
old_global_expert_indices_per_model = []
for _ in range(num_models):
metadata = torch.empty(3, dtype=torch.int32, device="cpu") metadata = torch.empty(3, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0) torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
num_moe_layers, num_logical_experts, num_old_physical_experts = ( num_moe_layers, num_logical_experts, num_old_physical_experts = (
...@@ -603,10 +714,90 @@ class EplbState: ...@@ -603,10 +714,90 @@ class EplbState:
device=ep_group.device, device=ep_group.device,
) )
torch.distributed.broadcast( torch.distributed.broadcast(
old_global_expert_indices, group=ep_group.device_group, group_src=0 old_global_expert_indices,
group=ep_group.device_group,
group_src=0,
)
global_expert_loads.append(global_expert_load)
old_global_expert_indices_per_model.append(old_global_expert_indices)
return global_expert_loads, old_global_expert_indices_per_model
@classmethod
def get_eep_state(
cls, parallel_config: ParallelConfig
) -> tuple[
list[torch.Tensor] | None,
list[torch.Tensor] | None,
dict[int, int] | None,
]:
num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(
num_local_physical_experts,
group=get_ep_group().cpu_group,
group_src=0,
)
num_local_physical_experts = int(num_local_physical_experts.item())
new_ep_size = get_ep_group().world_size
global_expert_loads, old_global_expert_indices_per_model = (
EplbState.recv_state()
) )
return global_expert_load, old_global_expert_indices # EP configuration for all models has to be the same so as eplb config
num_logical_experts = global_expert_loads[0].shape[1]
parallel_config.eplb_config.num_redundant_experts = (
num_local_physical_experts * new_ep_size - num_logical_experts
)
assert (
old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts
== 0
)
old_ep_size = (
old_global_expert_indices_per_model[0].shape[1]
// num_local_physical_experts
)
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
return (
global_expert_loads,
old_global_expert_indices_per_model,
rank_mapping,
)
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
"""
All-reduce a list of tensors.
"""
if len(tensor_list) == 1:
all_reduce(tensor_list[0], group=get_ep_group().device_group)
return tensor_list
assert all(t.dim() == 2 for t in tensor_list), "All tensors must be 2D."
assert all(t.shape[1] == tensor_list[0].shape[1] for t in tensor_list), (
"All tensors must have the same shape[1]."
)
# Concatenate, all_reduce, then unpack to original shapes.
# We assume all tensors are 2D and shape[1] (num_physical_experts)
# is the same across all models.
shapes = [t.shape for t in tensor_list]
concat_tensor = torch.cat(tensor_list, dim=0)
ep_group = get_ep_group().device_group
all_reduce(concat_tensor, group=ep_group)
all_reduce_list = []
offset = 0
for shape in shapes:
all_reduce_list.append(concat_tensor[offset : offset + shape[0], :])
offset += shape[0]
return all_reduce_list
def _sync_load_pass(self) -> list[torch.Tensor]:
"""
Sync the expert load pass across all ranks for log stats.
Doesn't update the expert load pass in eplb_model_state.
"""
load_pass_list = []
for eplb_model_state in self.model_states.values():
load_pass_list.append(eplb_model_state.expert_load_pass.clone())
return self._allreduce_list(load_pass_list)
def _node_count_with_rank_mapping( def _node_count_with_rank_mapping(
......
...@@ -226,7 +226,7 @@ class ToolParserManager: ...@@ -226,7 +226,7 @@ class ToolParserManager:
if isinstance(name, str): if isinstance(name, str):
names = [name] names = [name]
elif is_list_of(name, str): elif name is not None and is_list_of(name, str):
names = name names = name
else: else:
names = [class_name] names = [class_name]
......
...@@ -24,9 +24,12 @@ from vllm.model_executor.models.deepseek_v2 import ( ...@@ -24,9 +24,12 @@ from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2DecoderLayer, DeepseekV2DecoderLayer,
DeepseekV3ForCausalLM, DeepseekV3ForCausalLM,
) )
from vllm.utils import init_logger
from .utils import AutoWeightsLoader, maybe_prefix from .utils import AutoWeightsLoader, maybe_prefix
logger = init_logger(__name__)
@support_torch_compile @support_torch_compile
class DeepseekV2Model(nn.Module): class DeepseekV2Model(nn.Module):
...@@ -215,6 +218,10 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): ...@@ -215,6 +218,10 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
self.config.vocab_size, scale=logit_scale self.config.vocab_size, scale=logit_scale
) )
# Set MoE hyperparameters
self.num_moe_layers = self.config.num_hidden_layers
self.set_moe_parameters()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -8,6 +8,7 @@ from transformers import PretrainedConfig ...@@ -8,6 +8,7 @@ from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -25,11 +26,15 @@ from vllm.sequence import IntermediateTensors ...@@ -25,11 +26,15 @@ from vllm.sequence import IntermediateTensors
from .deepseek_v2 import ( from .deepseek_v2 import (
DeepseekV2DecoderLayer, DeepseekV2DecoderLayer,
DeepseekV2MixtureOfExperts,
DeepseekV2MoE,
get_spec_layer_idx_from_weight_name, get_spec_layer_idx_from_weight_name,
) )
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import maybe_prefix from .utils import maybe_prefix
logger = init_logger(__name__)
class SharedHead(nn.Module): class SharedHead(nn.Module):
def __init__( def __init__(
...@@ -119,6 +124,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -119,6 +124,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
self.mtp_start_layer_idx = config.num_hidden_layers self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers self.num_mtp_layers = config.num_nextn_predict_layers
# to map the exact layer index from weights # to map the exact layer index from weights
self.layers = torch.nn.ModuleDict( self.layers = torch.nn.ModuleDict(
{ {
str(idx): DeepSeekMultiTokenPredictorLayer( str(idx): DeepSeekMultiTokenPredictorLayer(
...@@ -172,13 +178,33 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -172,13 +178,33 @@ class DeepSeekMultiTokenPredictor(nn.Module):
@support_torch_compile @support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP): class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
self.model = DeepSeekMultiTokenPredictor( self.model = DeepSeekMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
# Set MoE hyperparameters
self.set_moe_parameters()
def set_moe_parameters(self):
self.expert_weights = []
self.num_moe_layers = self.config.num_nextn_predict_layers
self.num_expert_groups = self.config.n_group
self.moe_layers = []
self.moe_mlp_layers = []
example_moe = None
for layer in self.model.layers.values():
assert isinstance(layer, DeepSeekMultiTokenPredictorLayer)
layer = layer.mtp_block
assert isinstance(layer, DeepseekV2DecoderLayer)
if isinstance(layer.mlp, DeepseekV2MoE):
example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts)
self.extract_moe_parameters(example_moe)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -166,7 +166,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -166,7 +166,7 @@ class DeepseekV2MoE(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.n_routed_experts self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts self.n_shared_experts: int = config.n_shared_experts
...@@ -1122,7 +1122,6 @@ class DeepseekV2Model(nn.Module): ...@@ -1122,7 +1122,6 @@ class DeepseekV2Model(nn.Module):
) )
else: else:
self.embed_tokens = PPMissingLayer() self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer( lambda prefix: DeepseekV2DecoderLayer(
...@@ -1172,7 +1171,50 @@ class DeepseekV2Model(nn.Module): ...@@ -1172,7 +1171,50 @@ class DeepseekV2Model(nn.Module):
return hidden_states return hidden_states
class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA): class DeepseekV2MixtureOfExperts(MixtureOfExperts):
moe_mlp_layers: list[DeepseekV2MoE]
"""
List of MoE MLP layers in the model.
"""
def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None):
if example_moe is None:
self.num_moe_layers = 0
self.num_expert_groups = 0
self.num_logical_experts = 0
self.num_physical_experts = 0
self.num_local_physical_experts = 0
self.num_routed_experts = 0
self.num_shared_experts = 0
self.num_redundant_experts = 0
logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.")
else:
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for moe in self.moe_mlp_layers:
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
class DeepseekV2ForCausalLM(
nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA
):
packed_modules_mapping = { packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
} }
...@@ -1213,13 +1255,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR ...@@ -1213,13 +1255,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
# Set MoE hyperparameters
self.num_moe_layers = (
self.config.num_hidden_layers - self.config.first_k_dense_replace
)
self.set_moe_parameters()
def set_moe_parameters(self):
self.expert_weights = [] self.expert_weights = []
# Set MoE hyperparameters self.num_expert_groups = self.config.n_group
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = config.n_group
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
self.moe_mlp_layers = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
...@@ -1229,50 +1277,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR ...@@ -1229,50 +1277,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
if isinstance(layer.mlp, DeepseekV2MoE): if isinstance(layer.mlp, DeepseekV2MoE):
# Pick last one layer since the first ones may be dense layers. # Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts) self.moe_layers.append(layer.mlp.experts)
if example_moe is None: self.extract_moe_parameters(example_moe)
raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.model.layers:
if isinstance(layer.mlp, DeepseekV2MoE):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -133,7 +133,7 @@ class Ernie4_5_MoeMoE(nn.Module): ...@@ -133,7 +133,7 @@ class Ernie4_5_MoeMoE(nn.Module):
self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None) self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None)
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.moe_num_experts self.n_routed_experts: int = config.moe_num_experts
self.n_shared_experts: int = self.moe_num_shared_experts self.n_shared_experts: int = self.moe_num_shared_experts
...@@ -709,22 +709,6 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe ...@@ -709,22 +709,6 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
self.num_shared_experts = example_moe.n_shared_experts self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
...@@ -62,7 +62,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -62,7 +62,7 @@ from vllm.model_executor.model_loader.weight_utils import (
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
...@@ -127,7 +127,7 @@ class Glm4MoE(nn.Module): ...@@ -127,7 +127,7 @@ class Glm4MoE(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.n_routed_experts self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts self.n_shared_experts: int = config.n_shared_experts
...@@ -616,7 +616,35 @@ class Glm4MoeModel(nn.Module): ...@@ -616,7 +616,35 @@ class Glm4MoeModel(nn.Module):
return loaded_params return loaded_params
class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): class Glm4MixtureOfExperts(MixtureOfExperts):
def extract_moe_parameters(self, example_moe: Glm4MoE | None) -> None:
if example_moe is None:
raise RuntimeError("No Glm4MoE layer found in model.layers.")
else:
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for moe in self.moe_mlp_layers:
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, Glm4MixtureOfExperts):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -659,7 +687,9 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -659,7 +687,9 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = config.n_group self.num_expert_groups = config.n_group
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
self.moe_mlp_layers: list[Glm4MoE] = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
...@@ -669,33 +699,10 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -669,33 +699,10 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
if isinstance(layer.mlp, Glm4MoE): if isinstance(layer.mlp, Glm4MoE):
# Pick last one layer since the first ones may be dense layers. # Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts) self.moe_layers.append(layer.mlp.experts)
if example_moe is None: self.extract_moe_parameters(example_moe)
raise RuntimeError("No Glm4MoE layer found in model.layers.")
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -29,7 +29,7 @@ import torch ...@@ -29,7 +29,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -41,7 +41,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -41,7 +41,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name from .glm4_moe import (
Glm4MixtureOfExperts,
Glm4MoE,
Glm4MoeDecoderLayer,
get_spec_layer_idx_from_weight_name,
)
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import maybe_prefix from .utils import maybe_prefix
...@@ -73,6 +78,7 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module): ...@@ -73,6 +78,7 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module):
prefix: str, prefix: str,
cache_config: CacheConfig | None = None, cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
parallel_config: ParallelConfig | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -81,11 +87,13 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module): ...@@ -81,11 +87,13 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module):
self.shared_head = SharedHead( self.shared_head = SharedHead(
config=config, prefix=prefix, quant_config=quant_config config=config, prefix=prefix, quant_config=quant_config
) )
self.enable_eplb = parallel_config.enable_eplb
self.mtp_block = Glm4MoeDecoderLayer( self.mtp_block = Glm4MoeDecoderLayer(
config=config, config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
enable_eplb=self.enable_eplb,
) )
def forward( def forward(
...@@ -127,6 +135,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module): ...@@ -127,6 +135,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
f"{prefix}.layers.{idx}", f"{prefix}.layers.{idx}",
cache_config=vllm_config.cache_config, cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config, quant_config=vllm_config.quant_config,
parallel_config=vllm_config.parallel_config,
) )
for idx in range( for idx in range(
self.mtp_start_layer_idx, self.mtp_start_layer_idx,
...@@ -175,7 +184,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module): ...@@ -175,7 +184,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
return logits return logits
class Glm4MoeMTP(nn.Module, SupportsPP): class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
...@@ -183,6 +192,25 @@ class Glm4MoeMTP(nn.Module, SupportsPP): ...@@ -183,6 +192,25 @@ class Glm4MoeMTP(nn.Module, SupportsPP):
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
self.expert_weights = []
# Set MoE hyperparameters
self.num_moe_layers = self.config.num_nextn_predict_layers
self.num_expert_groups = self.config.n_group
self.moe_layers: list[FusedMoE] = []
self.moe_mlp_layers: list[Glm4MoE] = []
example_moe = None
for layer in self.model.layers.values():
assert isinstance(layer, Glm4MoeMultiTokenPredictorLayer)
layer = layer.mtp_block
assert isinstance(layer, Glm4MoeDecoderLayer)
if isinstance(layer.mlp, Glm4MoE):
example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts)
self.extract_moe_parameters(example_moe)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -374,7 +374,7 @@ class HunYuanSparseMoeBlock(nn.Module): ...@@ -374,7 +374,7 @@ class HunYuanSparseMoeBlock(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
...@@ -1007,7 +1007,7 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts): ...@@ -1007,7 +1007,7 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights = [] self.expert_weights = []
self.num_expert_groups = 1 self.num_expert_groups = 1
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
example_layer = None example_layer = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
...@@ -1028,22 +1028,6 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts): ...@@ -1028,22 +1028,6 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
self.num_routed_experts = example_layer.n_routed_experts self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
self.expert_weights.append(layer.get_expert_weights())
# Register the expert weights.
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
...@@ -14,6 +14,7 @@ from typing import ( ...@@ -14,6 +14,7 @@ from typing import (
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from torch import Tensor from torch import Tensor
from transformers import PretrainedConfig from transformers import PretrainedConfig
from transformers.models.whisper.tokenization_whisper import LANGUAGES from transformers.models.whisper.tokenization_whisper import LANGUAGES
...@@ -641,6 +642,9 @@ class MixtureOfExperts(Protocol): ...@@ -641,6 +642,9 @@ class MixtureOfExperts(Protocol):
num_redundant_experts: int num_redundant_experts: int
"""Number of redundant experts in this model.""" """Number of redundant experts in this model."""
moe_layers: Iterable[nn.Module]
"""List of MoE layers in this model."""
def set_eplb_state( def set_eplb_state(
self, self,
expert_load_view: Tensor, expert_load_view: Tensor,
...@@ -663,7 +667,15 @@ class MixtureOfExperts(Protocol): ...@@ -663,7 +667,15 @@ class MixtureOfExperts(Protocol):
logical_to_physical_map: Mapping from logical to physical experts. logical_to_physical_map: Mapping from logical to physical experts.
logical_replica_count: Count of replicas for each logical expert. logical_replica_count: Count of replicas for each logical expert.
""" """
... for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
......
...@@ -105,7 +105,7 @@ class Lfm2MoeSparseMoeBlock(nn.Module): ...@@ -105,7 +105,7 @@ class Lfm2MoeSparseMoeBlock(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
...@@ -707,7 +707,7 @@ class Lfm2MoeForCausalLM( ...@@ -707,7 +707,7 @@ class Lfm2MoeForCausalLM(
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights = [] self.expert_weights = []
self.moe_layers: list[FusedMoE] = [] self.moe_layers = []
example_layer = None example_layer = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
...@@ -737,22 +737,6 @@ class Lfm2MoeForCausalLM( ...@@ -737,22 +737,6 @@ class Lfm2MoeForCausalLM(
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
...@@ -30,9 +30,11 @@ from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention ...@@ -30,9 +30,11 @@ from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import ( from vllm.distributed import (
get_ep_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
...@@ -46,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -46,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from vllm.model_executor.models.interfaces import MixtureOfExperts
from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.models.utils import sequence_parallel_chunk
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
...@@ -56,6 +59,8 @@ from .utils import ( ...@@ -56,6 +59,8 @@ from .utils import (
is_pp_missing_parameter, is_pp_missing_parameter,
) )
logger = init_logger(__name__)
class Llama4MoE(nn.Module): class Llama4MoE(nn.Module):
@staticmethod @staticmethod
...@@ -80,6 +85,9 @@ class Llama4MoE(nn.Module): ...@@ -80,6 +85,9 @@ class Llama4MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok self.top_k = config.num_experts_per_tok
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
self.ep_group = get_ep_group().device_group
self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size()
intermediate_size_moe = config.intermediate_size intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear( self.router = ReplicatedLinear(
...@@ -101,6 +109,20 @@ class Llama4MoE(nn.Module): ...@@ -101,6 +109,20 @@ class Llama4MoE(nn.Module):
disable_tp=self.is_sequence_parallel, disable_tp=self.is_sequence_parallel,
) )
# Load balancing settings.
eplb_config = parallel_config.eplb_config if parallel_config else None
self.enable_eplb = parallel_config.enable_eplb if parallel_config else False
self.n_redundant_experts = (
eplb_config.num_redundant_experts if eplb_config else 0
)
self.n_routed_experts: int = config.num_local_experts
self.n_logical_experts = self.n_routed_experts
self.n_shared_experts: int = 1
self.n_local_experts: int = config.num_local_experts
self.n_physical_experts = self.n_local_experts + self.n_redundant_experts
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.experts = SharedFusedMoE( self.experts = SharedFusedMoE(
shared_experts=self.shared_expert, shared_experts=self.shared_expert,
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
...@@ -114,6 +136,8 @@ class Llama4MoE(nn.Module): ...@@ -114,6 +136,8 @@ class Llama4MoE(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
is_sequence_parallel=self.is_sequence_parallel, is_sequence_parallel=self.is_sequence_parallel,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
) )
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -378,6 +402,9 @@ class Llama4Model(LlamaModel): ...@@ -378,6 +402,9 @@ class Llama4Model(LlamaModel):
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer, layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
): ):
self.num_experts = vllm_config.model_config.hf_config.num_local_experts self.num_experts = vllm_config.model_config.hf_config.num_local_experts
self.n_redundant_experts = (
vllm_config.parallel_config.eplb_config.num_redundant_experts
)
super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
def load_moe_expert_weights( def load_moe_expert_weights(
...@@ -499,7 +526,6 @@ class Llama4Model(LlamaModel): ...@@ -499,7 +526,6 @@ class Llama4Model(LlamaModel):
shard_id=shard_id, shard_id=shard_id,
expert_id=expert_id, expert_id=expert_id,
) )
loaded_params.add(full_param_name) loaded_params.add(full_param_name)
expert_param_loaded = True expert_param_loaded = True
...@@ -526,6 +552,7 @@ class Llama4Model(LlamaModel): ...@@ -526,6 +552,7 @@ class Llama4Model(LlamaModel):
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.num_experts, num_experts=self.num_experts,
num_redundant_experts=self.n_redundant_experts,
) )
# Expert parameter mapping for the case where the expert weights are # Expert parameter mapping for the case where the expert weights are
# fused into a single weight tensor. # fused into a single weight tensor.
...@@ -683,7 +710,7 @@ class Llama4Model(LlamaModel): ...@@ -683,7 +710,7 @@ class Llama4Model(LlamaModel):
return loaded_params return loaded_params
class Llama4ForCausalLM(LlamaForCausalLM): class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
...@@ -702,6 +729,57 @@ class Llama4ForCausalLM(LlamaForCausalLM): ...@@ -702,6 +729,57 @@ class Llama4ForCausalLM(LlamaForCausalLM):
super().__init__( super().__init__(
vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer
) )
# Set MoE hyperparameters
self.set_moe_parameters()
def set_moe_parameters(self):
self.expert_weights = []
self.moe_layers = []
example_moe = None
for layer in self.model.layers:
assert isinstance(layer, Llama4DecoderLayer)
if isinstance(layer.feed_forward, Llama4MoE):
# Pick last one layer since the first ones may be dense layers.
example_moe = layer.feed_forward
self.moe_layers.append(layer.feed_forward.experts)
if example_moe is None:
self.num_moe_layers = 0
self.num_expert_groups = 0
self.num_logical_experts = 0
self.num_physical_experts = 0
self.num_local_physical_experts = 0
self.num_routed_experts = 0
self.num_shared_experts = 0
self.num_redundant_experts = 0
logger.warning("No Llama4MoE layer found in model.layers.")
else:
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.model.layers:
if isinstance(layer.feed_forward, Llama4MoE):
moe = layer.feed_forward
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def _init_model( def _init_model(
self, self,
......
...@@ -189,6 +189,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): ...@@ -189,6 +189,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
self.config.vocab_size, scale=logit_scale self.config.vocab_size, scale=logit_scale
) )
# Set MoE hyperparameters
self.set_moe_parameters()
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.model return self.model
......
...@@ -578,6 +578,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ...@@ -578,6 +578,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
parallel_config = vllm_config.parallel_config
self.prefix = prefix self.prefix = prefix
self.vllm_config = vllm_config self.vllm_config = vllm_config
...@@ -613,6 +614,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ...@@ -613,6 +614,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
if parallel_config.enable_eplb and getattr(config, "num_experts", 0) > 0:
raise NotImplementedError("EPLB is not supported for MiniCPM yet.")
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
return MiniCPMModel(vllm_config=vllm_config, prefix=prefix) return MiniCPMModel(vllm_config=vllm_config, prefix=prefix)
......
...@@ -98,7 +98,7 @@ class MixtralMoE(nn.Module): ...@@ -98,7 +98,7 @@ class MixtralMoE(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
# Expert Parallelism Load balancing settings. # Expert Parallelism Load balancing settings.
...@@ -546,7 +546,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -546,7 +546,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
) )
self.expert_weights = [] self.expert_weights = []
self.moe_layers: list[FusedMoE] = [] self.moe_layers = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
...@@ -572,22 +572,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -572,22 +572,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
self.num_expert_groups = 1 self.num_expert_groups = 1
self.num_shared_experts = 0 self.num_shared_experts = 0
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
...@@ -65,6 +65,7 @@ from vllm.sequence import IntermediateTensors ...@@ -65,6 +65,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
MixtureOfExperts,
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle3, SupportsEagle3,
SupportsMultiModal, SupportsMultiModal,
...@@ -723,7 +724,7 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): ...@@ -723,7 +724,7 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
dummy_inputs=Mllama4DummyInputsBuilder, dummy_inputs=Mllama4DummyInputsBuilder,
) )
class Llama4ForConditionalGeneration( class Llama4ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3 nn.Module, SupportsMultiModal, SupportsPP, MixtureOfExperts, SupportsEagle3
): ):
merge_by_field_config = True merge_by_field_config = True
...@@ -776,6 +777,17 @@ class Llama4ForConditionalGeneration( ...@@ -776,6 +777,17 @@ class Llama4ForConditionalGeneration(
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
# Set MoE hyperparameters
self.num_expert_groups = 1
self.num_logical_experts = self.language_model.num_logical_experts
self.num_physical_experts = self.language_model.num_physical_experts
self.num_local_physical_experts = self.language_model.num_local_physical_experts
self.num_routed_experts = self.language_model.num_routed_experts
self.num_shared_experts = self.language_model.num_shared_experts
self.num_redundant_experts = self.language_model.num_redundant_experts
self.moe_layers = self.language_model.moe_layers
self.num_moe_layers = len(self.moe_layers)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""Set which layers should output auxiliary hidden states for EAGLE3.""" """Set which layers should output auxiliary hidden states for EAGLE3."""
# Delegate to underlying language model (Llama4ForCausalLM) # Delegate to underlying language model (Llama4ForCausalLM)
...@@ -792,6 +804,24 @@ class Llama4ForConditionalGeneration( ...@@ -792,6 +804,24 @@ class Llama4ForConditionalGeneration(
assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers") assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
return self.language_model.get_eagle3_aux_hidden_state_layers() return self.language_model.get_eagle3_aux_hidden_state_layers()
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
):
self.language_model.set_eplb_state(
expert_load_view, logical_to_physical_map, logical_replica_count
)
self.expert_weights = self.language_model.expert_weights
def update_physical_experts_metadata(
self, num_physical_experts: int, num_local_physical_experts: int
):
self.language_model.update_physical_experts_metadata(
num_physical_experts, num_local_physical_experts
)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
) -> Llama4ImagePatchInputs | None: ) -> Llama4ImagePatchInputs | None:
......
...@@ -807,7 +807,7 @@ class NemotronHForCausalLM( ...@@ -807,7 +807,7 @@ class NemotronHForCausalLM(
self.expert_weights = [] self.expert_weights = []
self.num_expert_groups = config.n_group self.num_expert_groups = config.n_group
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, NemotronHMoEDecoderLayer): if isinstance(layer, NemotronHMoEDecoderLayer):
...@@ -824,22 +824,6 @@ class NemotronHForCausalLM( ...@@ -824,22 +824,6 @@ class NemotronHForCausalLM(
self.num_shared_experts = example_moe.n_shared_experts self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
...@@ -1009,7 +1009,7 @@ class OpenPanguMoEModel(OpenPanguModelBase, MixtureOfExperts): ...@@ -1009,7 +1009,7 @@ class OpenPanguMoEModel(OpenPanguModelBase, MixtureOfExperts):
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = 1 self.num_expert_groups = 1
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
...@@ -1031,22 +1031,6 @@ class OpenPanguMoEModel(OpenPanguModelBase, MixtureOfExperts): ...@@ -1031,22 +1031,6 @@ class OpenPanguMoEModel(OpenPanguModelBase, MixtureOfExperts):
self.n_shared_experts = example_moe.n_shared_experts self.n_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment