Unverified Commit e50c4546 authored by Ilya Markov's avatar Ilya Markov Committed by GitHub
Browse files

[BugFix] Support EP/DP + EPLB with MTP (#25311)


Signed-off-by: default avatarilmarkov <markovilya197@gmail.com>
Signed-off-by: default avatarSage Moore <sage@neuralmagic.com>
Co-authored-by: default avatarSage Moore <sage@neuralmagic.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
parent 5d16d0fa
...@@ -232,8 +232,8 @@ steps: ...@@ -232,8 +232,8 @@ steps:
commands: commands:
- pytest -v -s distributed/test_eplb_algo.py - pytest -v -s distributed/test_eplb_algo.py
- label: EPLB Execution Test # 5min - label: EPLB Execution Test # 10min
timeout_in_minutes: 15 timeout_in_minutes: 20
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 4 num_gpus: 4
source_file_dependencies: source_file_dependencies:
...@@ -241,6 +241,7 @@ steps: ...@@ -241,6 +241,7 @@ steps:
- tests/distributed/test_eplb_execute.py - tests/distributed/test_eplb_execute.py
commands: commands:
- pytest -v -s distributed/test_eplb_execute.py - pytest -v -s distributed/test_eplb_execute.py
- pytest -v -s distributed/test_eplb_spec_decode.py
- label: Metrics, Tracing Test # 12min - label: Metrics, Tracing Test # 12min
timeout_in_minutes: 20 timeout_in_minutes: 20
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import lm_eval
import pytest
from tests.utils import large_gpu_mark
def get_model_args(
model_name: str,
spec_model_name: str,
spec_method: str,
tp_size: int,
model_max_len: int,
) -> dict:
speculative_config = {
"method": spec_method,
"model": spec_model_name,
"num_speculative_tokens": 1,
"max_model_len": model_max_len,
}
model_args = {
"pretrained": model_name,
"dtype": "auto",
"add_bos_token": True,
"tensor_parallel_size": tp_size,
"gpu_memory_utilization": 0.7,
"speculative_config": speculative_config,
"enable_expert_parallel": True,
"num_redundant_experts": tp_size,
"eplb_window_size": 128,
"eplb_step_interval": 1024,
"eplb_log_balancedness": False,
"enable_eplb": True,
"max_model_len": model_max_len,
}
return model_args
@pytest.mark.parametrize(
"model_setup",
[
pytest.param(
("mtp", "Qwen/Qwen3-Next-80B-A3B-Instruct", None, 4, 0.86),
marks=large_gpu_mark(min_gb=80),
),
pytest.param(
(
"eagle",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
4,
0.92,
),
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues"),
),
],
ids=["qwen3_next_mtp", "llama4_eagle"],
)
def test_eplb_spec_decode(
monkeypatch: pytest.MonkeyPatch,
model_setup: tuple[str, str, str, int, float],
):
"""
Test the correctness of EPLB speculative decoding with GSM8K dataset.
Applicable to MoE models with mtp or eagle spec decode.
"""
method, model_name, spec_model_name, tp_size, expected_gsm8k_value = model_setup
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
model_args = get_model_args(
model_name=model_name,
spec_model_name=spec_model_name,
spec_method=method,
tp_size=tp_size,
model_max_len=4096,
)
results = lm_eval.simple_evaluate(
model="vllm",
model_args=model_args,
tasks=TASK,
batch_size=64,
num_fewshot=8,
)
measured_value = results["results"][TASK][FILTER]
assert (
measured_value - RTOL < expected_gsm8k_value
and measured_value + RTOL > expected_gsm8k_value
), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
This diff is collapsed.
...@@ -226,7 +226,7 @@ class ToolParserManager: ...@@ -226,7 +226,7 @@ class ToolParserManager:
if isinstance(name, str): if isinstance(name, str):
names = [name] names = [name]
elif is_list_of(name, str): elif name is not None and is_list_of(name, str):
names = name names = name
else: else:
names = [class_name] names = [class_name]
......
...@@ -24,9 +24,12 @@ from vllm.model_executor.models.deepseek_v2 import ( ...@@ -24,9 +24,12 @@ from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2DecoderLayer, DeepseekV2DecoderLayer,
DeepseekV3ForCausalLM, DeepseekV3ForCausalLM,
) )
from vllm.utils import init_logger
from .utils import AutoWeightsLoader, maybe_prefix from .utils import AutoWeightsLoader, maybe_prefix
logger = init_logger(__name__)
@support_torch_compile @support_torch_compile
class DeepseekV2Model(nn.Module): class DeepseekV2Model(nn.Module):
...@@ -215,6 +218,10 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): ...@@ -215,6 +218,10 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
self.config.vocab_size, scale=logit_scale self.config.vocab_size, scale=logit_scale
) )
# Set MoE hyperparameters
self.num_moe_layers = self.config.num_hidden_layers
self.set_moe_parameters()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -8,6 +8,7 @@ from transformers import PretrainedConfig ...@@ -8,6 +8,7 @@ from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -25,11 +26,15 @@ from vllm.sequence import IntermediateTensors ...@@ -25,11 +26,15 @@ from vllm.sequence import IntermediateTensors
from .deepseek_v2 import ( from .deepseek_v2 import (
DeepseekV2DecoderLayer, DeepseekV2DecoderLayer,
DeepseekV2MixtureOfExperts,
DeepseekV2MoE,
get_spec_layer_idx_from_weight_name, get_spec_layer_idx_from_weight_name,
) )
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import maybe_prefix from .utils import maybe_prefix
logger = init_logger(__name__)
class SharedHead(nn.Module): class SharedHead(nn.Module):
def __init__( def __init__(
...@@ -119,6 +124,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -119,6 +124,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
self.mtp_start_layer_idx = config.num_hidden_layers self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers self.num_mtp_layers = config.num_nextn_predict_layers
# to map the exact layer index from weights # to map the exact layer index from weights
self.layers = torch.nn.ModuleDict( self.layers = torch.nn.ModuleDict(
{ {
str(idx): DeepSeekMultiTokenPredictorLayer( str(idx): DeepSeekMultiTokenPredictorLayer(
...@@ -172,13 +178,33 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -172,13 +178,33 @@ class DeepSeekMultiTokenPredictor(nn.Module):
@support_torch_compile @support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP): class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
self.model = DeepSeekMultiTokenPredictor( self.model = DeepSeekMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
# Set MoE hyperparameters
self.set_moe_parameters()
def set_moe_parameters(self):
self.expert_weights = []
self.num_moe_layers = self.config.num_nextn_predict_layers
self.num_expert_groups = self.config.n_group
self.moe_layers = []
self.moe_mlp_layers = []
example_moe = None
for layer in self.model.layers.values():
assert isinstance(layer, DeepSeekMultiTokenPredictorLayer)
layer = layer.mtp_block
assert isinstance(layer, DeepseekV2DecoderLayer)
if isinstance(layer.mlp, DeepseekV2MoE):
example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts)
self.extract_moe_parameters(example_moe)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -166,7 +166,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -166,7 +166,7 @@ class DeepseekV2MoE(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.n_routed_experts self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts self.n_shared_experts: int = config.n_shared_experts
...@@ -1122,7 +1122,6 @@ class DeepseekV2Model(nn.Module): ...@@ -1122,7 +1122,6 @@ class DeepseekV2Model(nn.Module):
) )
else: else:
self.embed_tokens = PPMissingLayer() self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer( lambda prefix: DeepseekV2DecoderLayer(
...@@ -1172,7 +1171,50 @@ class DeepseekV2Model(nn.Module): ...@@ -1172,7 +1171,50 @@ class DeepseekV2Model(nn.Module):
return hidden_states return hidden_states
class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA): class DeepseekV2MixtureOfExperts(MixtureOfExperts):
moe_mlp_layers: list[DeepseekV2MoE]
"""
List of MoE MLP layers in the model.
"""
def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None):
if example_moe is None:
self.num_moe_layers = 0
self.num_expert_groups = 0
self.num_logical_experts = 0
self.num_physical_experts = 0
self.num_local_physical_experts = 0
self.num_routed_experts = 0
self.num_shared_experts = 0
self.num_redundant_experts = 0
logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.")
else:
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for moe in self.moe_mlp_layers:
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
class DeepseekV2ForCausalLM(
nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA
):
packed_modules_mapping = { packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
} }
...@@ -1213,13 +1255,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR ...@@ -1213,13 +1255,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
# Set MoE hyperparameters
self.num_moe_layers = (
self.config.num_hidden_layers - self.config.first_k_dense_replace
)
self.set_moe_parameters()
def set_moe_parameters(self):
self.expert_weights = [] self.expert_weights = []
# Set MoE hyperparameters self.num_expert_groups = self.config.n_group
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = config.n_group
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
self.moe_mlp_layers = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
...@@ -1229,50 +1277,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR ...@@ -1229,50 +1277,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
if isinstance(layer.mlp, DeepseekV2MoE): if isinstance(layer.mlp, DeepseekV2MoE):
# Pick last one layer since the first ones may be dense layers. # Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts) self.moe_layers.append(layer.mlp.experts)
if example_moe is None: self.extract_moe_parameters(example_moe)
raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.model.layers:
if isinstance(layer.mlp, DeepseekV2MoE):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -133,7 +133,7 @@ class Ernie4_5_MoeMoE(nn.Module): ...@@ -133,7 +133,7 @@ class Ernie4_5_MoeMoE(nn.Module):
self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None) self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None)
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.moe_num_experts self.n_routed_experts: int = config.moe_num_experts
self.n_shared_experts: int = self.moe_num_shared_experts self.n_shared_experts: int = self.moe_num_shared_experts
...@@ -709,22 +709,6 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe ...@@ -709,22 +709,6 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
self.num_shared_experts = example_moe.n_shared_experts self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
...@@ -62,7 +62,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -62,7 +62,7 @@ from vllm.model_executor.model_loader.weight_utils import (
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
...@@ -127,7 +127,7 @@ class Glm4MoE(nn.Module): ...@@ -127,7 +127,7 @@ class Glm4MoE(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.n_routed_experts self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts self.n_shared_experts: int = config.n_shared_experts
...@@ -616,7 +616,35 @@ class Glm4MoeModel(nn.Module): ...@@ -616,7 +616,35 @@ class Glm4MoeModel(nn.Module):
return loaded_params return loaded_params
class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): class Glm4MixtureOfExperts(MixtureOfExperts):
def extract_moe_parameters(self, example_moe: Glm4MoE | None) -> None:
if example_moe is None:
raise RuntimeError("No Glm4MoE layer found in model.layers.")
else:
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for moe in self.moe_mlp_layers:
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, Glm4MixtureOfExperts):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -659,7 +687,9 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -659,7 +687,9 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = config.n_group self.num_expert_groups = config.n_group
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
self.moe_mlp_layers: list[Glm4MoE] = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
...@@ -669,33 +699,10 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -669,33 +699,10 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
if isinstance(layer.mlp, Glm4MoE): if isinstance(layer.mlp, Glm4MoE):
# Pick last one layer since the first ones may be dense layers. # Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts) self.moe_layers.append(layer.mlp.experts)
if example_moe is None: self.extract_moe_parameters(example_moe)
raise RuntimeError("No Glm4MoE layer found in model.layers.")
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -29,7 +29,7 @@ import torch ...@@ -29,7 +29,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -41,7 +41,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -41,7 +41,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name from .glm4_moe import (
Glm4MixtureOfExperts,
Glm4MoE,
Glm4MoeDecoderLayer,
get_spec_layer_idx_from_weight_name,
)
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import maybe_prefix from .utils import maybe_prefix
...@@ -73,6 +78,7 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module): ...@@ -73,6 +78,7 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module):
prefix: str, prefix: str,
cache_config: CacheConfig | None = None, cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
parallel_config: ParallelConfig | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -81,11 +87,13 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module): ...@@ -81,11 +87,13 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module):
self.shared_head = SharedHead( self.shared_head = SharedHead(
config=config, prefix=prefix, quant_config=quant_config config=config, prefix=prefix, quant_config=quant_config
) )
self.enable_eplb = parallel_config.enable_eplb
self.mtp_block = Glm4MoeDecoderLayer( self.mtp_block = Glm4MoeDecoderLayer(
config=config, config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
enable_eplb=self.enable_eplb,
) )
def forward( def forward(
...@@ -127,6 +135,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module): ...@@ -127,6 +135,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
f"{prefix}.layers.{idx}", f"{prefix}.layers.{idx}",
cache_config=vllm_config.cache_config, cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config, quant_config=vllm_config.quant_config,
parallel_config=vllm_config.parallel_config,
) )
for idx in range( for idx in range(
self.mtp_start_layer_idx, self.mtp_start_layer_idx,
...@@ -175,7 +184,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module): ...@@ -175,7 +184,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
return logits return logits
class Glm4MoeMTP(nn.Module, SupportsPP): class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
...@@ -183,6 +192,25 @@ class Glm4MoeMTP(nn.Module, SupportsPP): ...@@ -183,6 +192,25 @@ class Glm4MoeMTP(nn.Module, SupportsPP):
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
self.expert_weights = []
# Set MoE hyperparameters
self.num_moe_layers = self.config.num_nextn_predict_layers
self.num_expert_groups = self.config.n_group
self.moe_layers: list[FusedMoE] = []
self.moe_mlp_layers: list[Glm4MoE] = []
example_moe = None
for layer in self.model.layers.values():
assert isinstance(layer, Glm4MoeMultiTokenPredictorLayer)
layer = layer.mtp_block
assert isinstance(layer, Glm4MoeDecoderLayer)
if isinstance(layer.mlp, Glm4MoE):
example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts)
self.extract_moe_parameters(example_moe)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -374,7 +374,7 @@ class HunYuanSparseMoeBlock(nn.Module): ...@@ -374,7 +374,7 @@ class HunYuanSparseMoeBlock(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
...@@ -1007,7 +1007,7 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts): ...@@ -1007,7 +1007,7 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights = [] self.expert_weights = []
self.num_expert_groups = 1 self.num_expert_groups = 1
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
example_layer = None example_layer = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
...@@ -1028,22 +1028,6 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts): ...@@ -1028,22 +1028,6 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
self.num_routed_experts = example_layer.n_routed_experts self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
self.expert_weights.append(layer.get_expert_weights())
# Register the expert weights.
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
...@@ -14,6 +14,7 @@ from typing import ( ...@@ -14,6 +14,7 @@ from typing import (
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from torch import Tensor from torch import Tensor
from transformers import PretrainedConfig from transformers import PretrainedConfig
from transformers.models.whisper.tokenization_whisper import LANGUAGES from transformers.models.whisper.tokenization_whisper import LANGUAGES
...@@ -641,6 +642,9 @@ class MixtureOfExperts(Protocol): ...@@ -641,6 +642,9 @@ class MixtureOfExperts(Protocol):
num_redundant_experts: int num_redundant_experts: int
"""Number of redundant experts in this model.""" """Number of redundant experts in this model."""
moe_layers: Iterable[nn.Module]
"""List of MoE layers in this model."""
def set_eplb_state( def set_eplb_state(
self, self,
expert_load_view: Tensor, expert_load_view: Tensor,
...@@ -663,7 +667,15 @@ class MixtureOfExperts(Protocol): ...@@ -663,7 +667,15 @@ class MixtureOfExperts(Protocol):
logical_to_physical_map: Mapping from logical to physical experts. logical_to_physical_map: Mapping from logical to physical experts.
logical_replica_count: Count of replicas for each logical expert. logical_replica_count: Count of replicas for each logical expert.
""" """
... for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
......
...@@ -105,7 +105,7 @@ class Lfm2MoeSparseMoeBlock(nn.Module): ...@@ -105,7 +105,7 @@ class Lfm2MoeSparseMoeBlock(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
...@@ -707,7 +707,7 @@ class Lfm2MoeForCausalLM( ...@@ -707,7 +707,7 @@ class Lfm2MoeForCausalLM(
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights = [] self.expert_weights = []
self.moe_layers: list[FusedMoE] = [] self.moe_layers = []
example_layer = None example_layer = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
...@@ -737,22 +737,6 @@ class Lfm2MoeForCausalLM( ...@@ -737,22 +737,6 @@ class Lfm2MoeForCausalLM(
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
...@@ -30,9 +30,11 @@ from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention ...@@ -30,9 +30,11 @@ from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import ( from vllm.distributed import (
get_ep_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
...@@ -46,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -46,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from vllm.model_executor.models.interfaces import MixtureOfExperts
from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.models.utils import sequence_parallel_chunk
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
...@@ -56,6 +59,8 @@ from .utils import ( ...@@ -56,6 +59,8 @@ from .utils import (
is_pp_missing_parameter, is_pp_missing_parameter,
) )
logger = init_logger(__name__)
class Llama4MoE(nn.Module): class Llama4MoE(nn.Module):
@staticmethod @staticmethod
...@@ -80,6 +85,9 @@ class Llama4MoE(nn.Module): ...@@ -80,6 +85,9 @@ class Llama4MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok self.top_k = config.num_experts_per_tok
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
self.ep_group = get_ep_group().device_group
self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size()
intermediate_size_moe = config.intermediate_size intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear( self.router = ReplicatedLinear(
...@@ -101,6 +109,20 @@ class Llama4MoE(nn.Module): ...@@ -101,6 +109,20 @@ class Llama4MoE(nn.Module):
disable_tp=self.is_sequence_parallel, disable_tp=self.is_sequence_parallel,
) )
# Load balancing settings.
eplb_config = parallel_config.eplb_config if parallel_config else None
self.enable_eplb = parallel_config.enable_eplb if parallel_config else False
self.n_redundant_experts = (
eplb_config.num_redundant_experts if eplb_config else 0
)
self.n_routed_experts: int = config.num_local_experts
self.n_logical_experts = self.n_routed_experts
self.n_shared_experts: int = 1
self.n_local_experts: int = config.num_local_experts
self.n_physical_experts = self.n_local_experts + self.n_redundant_experts
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.experts = SharedFusedMoE( self.experts = SharedFusedMoE(
shared_experts=self.shared_expert, shared_experts=self.shared_expert,
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
...@@ -114,6 +136,8 @@ class Llama4MoE(nn.Module): ...@@ -114,6 +136,8 @@ class Llama4MoE(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
is_sequence_parallel=self.is_sequence_parallel, is_sequence_parallel=self.is_sequence_parallel,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
) )
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -378,6 +402,9 @@ class Llama4Model(LlamaModel): ...@@ -378,6 +402,9 @@ class Llama4Model(LlamaModel):
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer, layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
): ):
self.num_experts = vllm_config.model_config.hf_config.num_local_experts self.num_experts = vllm_config.model_config.hf_config.num_local_experts
self.n_redundant_experts = (
vllm_config.parallel_config.eplb_config.num_redundant_experts
)
super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
def load_moe_expert_weights( def load_moe_expert_weights(
...@@ -499,7 +526,6 @@ class Llama4Model(LlamaModel): ...@@ -499,7 +526,6 @@ class Llama4Model(LlamaModel):
shard_id=shard_id, shard_id=shard_id,
expert_id=expert_id, expert_id=expert_id,
) )
loaded_params.add(full_param_name) loaded_params.add(full_param_name)
expert_param_loaded = True expert_param_loaded = True
...@@ -526,6 +552,7 @@ class Llama4Model(LlamaModel): ...@@ -526,6 +552,7 @@ class Llama4Model(LlamaModel):
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.num_experts, num_experts=self.num_experts,
num_redundant_experts=self.n_redundant_experts,
) )
# Expert parameter mapping for the case where the expert weights are # Expert parameter mapping for the case where the expert weights are
# fused into a single weight tensor. # fused into a single weight tensor.
...@@ -683,7 +710,7 @@ class Llama4Model(LlamaModel): ...@@ -683,7 +710,7 @@ class Llama4Model(LlamaModel):
return loaded_params return loaded_params
class Llama4ForCausalLM(LlamaForCausalLM): class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
...@@ -702,6 +729,57 @@ class Llama4ForCausalLM(LlamaForCausalLM): ...@@ -702,6 +729,57 @@ class Llama4ForCausalLM(LlamaForCausalLM):
super().__init__( super().__init__(
vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer
) )
# Set MoE hyperparameters
self.set_moe_parameters()
def set_moe_parameters(self):
self.expert_weights = []
self.moe_layers = []
example_moe = None
for layer in self.model.layers:
assert isinstance(layer, Llama4DecoderLayer)
if isinstance(layer.feed_forward, Llama4MoE):
# Pick last one layer since the first ones may be dense layers.
example_moe = layer.feed_forward
self.moe_layers.append(layer.feed_forward.experts)
if example_moe is None:
self.num_moe_layers = 0
self.num_expert_groups = 0
self.num_logical_experts = 0
self.num_physical_experts = 0
self.num_local_physical_experts = 0
self.num_routed_experts = 0
self.num_shared_experts = 0
self.num_redundant_experts = 0
logger.warning("No Llama4MoE layer found in model.layers.")
else:
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.model.layers:
if isinstance(layer.feed_forward, Llama4MoE):
moe = layer.feed_forward
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def _init_model( def _init_model(
self, self,
......
...@@ -189,6 +189,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): ...@@ -189,6 +189,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
self.config.vocab_size, scale=logit_scale self.config.vocab_size, scale=logit_scale
) )
# Set MoE hyperparameters
self.set_moe_parameters()
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.model return self.model
......
...@@ -578,6 +578,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ...@@ -578,6 +578,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
parallel_config = vllm_config.parallel_config
self.prefix = prefix self.prefix = prefix
self.vllm_config = vllm_config self.vllm_config = vllm_config
...@@ -613,6 +614,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ...@@ -613,6 +614,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
if parallel_config.enable_eplb and getattr(config, "num_experts", 0) > 0:
raise NotImplementedError("EPLB is not supported for MiniCPM yet.")
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
return MiniCPMModel(vllm_config=vllm_config, prefix=prefix) return MiniCPMModel(vllm_config=vllm_config, prefix=prefix)
......
...@@ -98,7 +98,7 @@ class MixtralMoE(nn.Module): ...@@ -98,7 +98,7 @@ class MixtralMoE(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
# Expert Parallelism Load balancing settings. # Expert Parallelism Load balancing settings.
...@@ -546,7 +546,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -546,7 +546,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
) )
self.expert_weights = [] self.expert_weights = []
self.moe_layers: list[FusedMoE] = [] self.moe_layers = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
...@@ -572,22 +572,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -572,22 +572,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
self.num_expert_groups = 1 self.num_expert_groups = 1
self.num_shared_experts = 0 self.num_shared_experts = 0
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
...@@ -65,6 +65,7 @@ from vllm.sequence import IntermediateTensors ...@@ -65,6 +65,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
MixtureOfExperts,
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle3, SupportsEagle3,
SupportsMultiModal, SupportsMultiModal,
...@@ -723,7 +724,7 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): ...@@ -723,7 +724,7 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
dummy_inputs=Mllama4DummyInputsBuilder, dummy_inputs=Mllama4DummyInputsBuilder,
) )
class Llama4ForConditionalGeneration( class Llama4ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3 nn.Module, SupportsMultiModal, SupportsPP, MixtureOfExperts, SupportsEagle3
): ):
merge_by_field_config = True merge_by_field_config = True
...@@ -776,6 +777,17 @@ class Llama4ForConditionalGeneration( ...@@ -776,6 +777,17 @@ class Llama4ForConditionalGeneration(
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
# Set MoE hyperparameters
self.num_expert_groups = 1
self.num_logical_experts = self.language_model.num_logical_experts
self.num_physical_experts = self.language_model.num_physical_experts
self.num_local_physical_experts = self.language_model.num_local_physical_experts
self.num_routed_experts = self.language_model.num_routed_experts
self.num_shared_experts = self.language_model.num_shared_experts
self.num_redundant_experts = self.language_model.num_redundant_experts
self.moe_layers = self.language_model.moe_layers
self.num_moe_layers = len(self.moe_layers)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""Set which layers should output auxiliary hidden states for EAGLE3.""" """Set which layers should output auxiliary hidden states for EAGLE3."""
# Delegate to underlying language model (Llama4ForCausalLM) # Delegate to underlying language model (Llama4ForCausalLM)
...@@ -792,6 +804,24 @@ class Llama4ForConditionalGeneration( ...@@ -792,6 +804,24 @@ class Llama4ForConditionalGeneration(
assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers") assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
return self.language_model.get_eagle3_aux_hidden_state_layers() return self.language_model.get_eagle3_aux_hidden_state_layers()
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
):
self.language_model.set_eplb_state(
expert_load_view, logical_to_physical_map, logical_replica_count
)
self.expert_weights = self.language_model.expert_weights
def update_physical_experts_metadata(
self, num_physical_experts: int, num_local_physical_experts: int
):
self.language_model.update_physical_experts_metadata(
num_physical_experts, num_local_physical_experts
)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
) -> Llama4ImagePatchInputs | None: ) -> Llama4ImagePatchInputs | None:
......
...@@ -807,7 +807,7 @@ class NemotronHForCausalLM( ...@@ -807,7 +807,7 @@ class NemotronHForCausalLM(
self.expert_weights = [] self.expert_weights = []
self.num_expert_groups = config.n_group self.num_expert_groups = config.n_group
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, NemotronHMoEDecoderLayer): if isinstance(layer, NemotronHMoEDecoderLayer):
...@@ -824,22 +824,6 @@ class NemotronHForCausalLM( ...@@ -824,22 +824,6 @@ class NemotronHForCausalLM(
self.num_shared_experts = example_moe.n_shared_experts self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
...@@ -1009,7 +1009,7 @@ class OpenPanguMoEModel(OpenPanguModelBase, MixtureOfExperts): ...@@ -1009,7 +1009,7 @@ class OpenPanguMoEModel(OpenPanguModelBase, MixtureOfExperts):
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = 1 self.num_expert_groups = 1
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
...@@ -1031,22 +1031,6 @@ class OpenPanguMoEModel(OpenPanguModelBase, MixtureOfExperts): ...@@ -1031,22 +1031,6 @@ class OpenPanguMoEModel(OpenPanguModelBase, MixtureOfExperts):
self.n_shared_experts = example_moe.n_shared_experts self.n_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment