Unverified Commit e50c4546 authored by Ilya Markov's avatar Ilya Markov Committed by GitHub
Browse files

[BugFix] Support EP/DP + EPLB with MTP (#25311)


Signed-off-by: default avatarilmarkov <markovilya197@gmail.com>
Signed-off-by: default avatarSage Moore <sage@neuralmagic.com>
Co-authored-by: default avatarSage Moore <sage@neuralmagic.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
parent 5d16d0fa
...@@ -132,7 +132,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -132,7 +132,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
...@@ -665,7 +665,7 @@ class Qwen3MoeForCausalLM( ...@@ -665,7 +665,7 @@ class Qwen3MoeForCausalLM(
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights = [] self.expert_weights = []
self.moe_layers: list[FusedMoE] = [] self.moe_layers = []
example_layer = None example_layer = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
...@@ -688,22 +688,6 @@ class Qwen3MoeForCausalLM( ...@@ -688,22 +688,6 @@ class Qwen3MoeForCausalLM(
self.num_routed_experts = example_layer.n_routed_experts self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,
......
...@@ -107,7 +107,7 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -107,7 +107,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
...@@ -1095,8 +1095,57 @@ class Qwen3NextModel(nn.Module): ...@@ -1095,8 +1095,57 @@ class Qwen3NextModel(nn.Module):
return loaded_params return loaded_params
class QwenNextMixtureOfExperts(MixtureOfExperts):
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.model.layers:
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def set_moe_parameters(self):
self.expert_weights = []
self.moe_layers = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
layer.mlp, Qwen3NextSparseMoeBlock
):
example_moe = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_moe is None:
raise RuntimeError("No Qwen3Next layer found in the model.layers.")
# Set MoE hyperparameters
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_redundant_experts = example_moe.n_redundant_experts
class Qwen3NextForCausalLM( class Qwen3NextForCausalLM(
nn.Module, HasInnerState, SupportsLoRA, SupportsPP, MixtureOfExperts, IsHybrid nn.Module,
HasInnerState,
SupportsLoRA,
SupportsPP,
QwenNextMixtureOfExperts,
IsHybrid,
): ):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
...@@ -1147,63 +1196,7 @@ class Qwen3NextForCausalLM( ...@@ -1147,63 +1196,7 @@ class Qwen3NextForCausalLM(
) )
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights = [] self.set_moe_parameters()
self.moe_layers: list[SharedFusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Qwen3NextDecoderLayer)
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
raise RuntimeError("No Qwen3Next layer found in the model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
self.num_logical_experts = example_layer.n_logical_experts
self.num_physical_experts = example_layer.n_physical_experts
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.model.layers:
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -23,6 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -23,6 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_next import ( from vllm.model_executor.models.qwen3_next import (
Qwen3NextDecoderLayer, Qwen3NextDecoderLayer,
Qwen3NextRMSNorm, Qwen3NextRMSNorm,
QwenNextMixtureOfExperts,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.transformers_utils.configs import Qwen3NextConfig
...@@ -226,7 +227,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module): ...@@ -226,7 +227,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
@support_torch_compile @support_torch_compile
class Qwen3NextMTP(nn.Module, SupportsPP): class Qwen3NextMTP(nn.Module, SupportsPP, QwenNextMixtureOfExperts):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -265,6 +266,7 @@ class Qwen3NextMTP(nn.Module, SupportsPP): ...@@ -265,6 +266,7 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
self.set_moe_parameters()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -125,7 +125,7 @@ class MoEMixin(MixtureOfExperts): ...@@ -125,7 +125,7 @@ class MoEMixin(MixtureOfExperts):
logical_to_physical_map: torch.Tensor, logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor, logical_replica_count: torch.Tensor,
): ):
for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers): for moe_layer_idx, mlp_layer in enumerate(self.mlp_moe_layers):
mlp_layer.experts.set_eplb_state( mlp_layer.experts.set_eplb_state(
moe_layer_idx=moe_layer_idx, moe_layer_idx=moe_layer_idx,
expert_load_view=expert_load_view, expert_load_view=expert_load_view,
...@@ -142,7 +142,7 @@ class MoEMixin(MixtureOfExperts): ...@@ -142,7 +142,7 @@ class MoEMixin(MixtureOfExperts):
self.num_physical_experts = num_physical_experts self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for mlp in self.mlp_layers: for mlp in self.mlp_moe_layers:
mlp.n_local_physical_experts = num_local_physical_experts mlp.n_local_physical_experts = num_local_physical_experts
mlp.n_physical_experts = num_physical_experts mlp.n_physical_experts = num_physical_experts
mlp.n_redundant_experts = self.num_redundant_experts mlp.n_redundant_experts = self.num_redundant_experts
...@@ -240,7 +240,8 @@ class MoEMixin(MixtureOfExperts): ...@@ -240,7 +240,8 @@ class MoEMixin(MixtureOfExperts):
# MixtureOfExperts mixin settings # MixtureOfExperts mixin settings
ep_size = get_ep_group().world_size ep_size = get_ep_group().world_size
self.mlp_layers = [] # Used for MixtureOfExperts methods self.mlp_moe_layers = [] # Used for MixtureOfExperts methods
self.moe_layers = []
self.expert_weights = [] self.expert_weights = []
self.num_moe_layers = 0 self.num_moe_layers = 0
self.num_expert_groups = 1 if num_expert_group is None else num_expert_group self.num_expert_groups = 1 if num_expert_group is None else num_expert_group
...@@ -298,7 +299,8 @@ class MoEMixin(MixtureOfExperts): ...@@ -298,7 +299,8 @@ class MoEMixin(MixtureOfExperts):
mlp.experts = fused_experts mlp.experts = fused_experts
log_replacement(qual_name, experts, fused_experts) log_replacement(qual_name, experts, fused_experts)
# Update MixtureOfExperts mixin state # Update MixtureOfExperts mixin state
self.mlp_layers.append(mlp) self.mlp_moe_layers.append(mlp)
self.moe_layers.append(fused_experts)
self.expert_weights.append(fused_experts.get_expert_weights()) self.expert_weights.append(fused_experts.get_expert_weights())
self.num_moe_layers += 1 self.num_moe_layers += 1
# If results are not all-reduced in FusedMoE, ensure they # If results are not all-reduced in FusedMoE, ensure they
......
...@@ -8,6 +8,7 @@ from vllm.config import VllmConfig ...@@ -8,6 +8,7 @@ from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
# Initialize logger # Initialize logger
...@@ -56,6 +57,10 @@ class MedusaProposer: ...@@ -56,6 +57,10 @@ class MedusaProposer:
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
model_config=self.vllm_config.speculative_config.draft_model_config, model_config=self.vllm_config.speculative_config.draft_model_config,
) )
assert not (
is_mixture_of_experts(self.model)
and self.vllm_config.parallel_config.enable_eplb
), "EPLB for Medusa is not supported"
@torch.inference_mode() @torch.inference_mode()
def dummy_run(self, num_tokens: int) -> None: def dummy_run(self, num_tokens: int) -> None:
......
...@@ -2046,7 +2046,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2046,7 +2046,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model = self.get_model() model = self.get_model()
assert is_mixture_of_experts(model) assert is_mixture_of_experts(model)
self.eplb_state.step( self.eplb_state.step(
model,
is_dummy, is_dummy,
is_profile, is_profile,
log_stats=self.parallel_config.eplb_config.log_balancedness, log_stats=self.parallel_config.eplb_config.log_balancedness,
...@@ -2803,7 +2802,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2803,7 +2802,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else: else:
indices = [] indices = []
offset = 0 offset = 0
assert spec_decode_metadata is not None assert spec_decode_metadata is not None, (
"No spec decode metadata for medusa"
)
for num_draft, tokens in zip( for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens, sampled_token_ids spec_decode_metadata.num_draft_tokens, sampled_token_ids
): ):
...@@ -2934,32 +2935,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2934,32 +2935,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.model_config.model, self.model_config.model,
scope="global", scope="global",
) )
if eep_scale_up: global_expert_loads, old_global_expert_indices_per_model, rank_mapping = (
from vllm.distributed.parallel_state import get_ep_group EplbState.get_eep_state(self.parallel_config)
if eep_scale_up
num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") else (None, None, None)
torch.distributed.broadcast( )
num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0
)
num_local_physical_experts = int(num_local_physical_experts.item())
new_ep_size = get_ep_group().world_size
global_expert_load, old_global_expert_indices = EplbState.recv_state()
num_logical_experts = global_expert_load.shape[1]
self.parallel_config.eplb_config.num_redundant_experts = (
num_local_physical_experts * new_ep_size - num_logical_experts
)
assert old_global_expert_indices.shape[1] % num_local_physical_experts == 0
old_ep_size = (
old_global_expert_indices.shape[1] // num_local_physical_experts
)
rank_mapping = {
old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)
}
else:
global_expert_load = None
old_global_expert_indices = None
rank_mapping = None
if self.parallel_config.enable_eplb:
self.eplb_state = EplbState(self.parallel_config, self.device)
eplb_models = 0
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler() as m:
time_before_load = time.perf_counter() time_before_load = time.perf_counter()
model_loader = get_model_loader(self.load_config) model_loader = get_model_loader(self.load_config)
...@@ -2971,8 +2955,39 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2971,8 +2955,39 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.model, self.vllm_config, self.device self.model, self.vllm_config, self.device
) )
if hasattr(self, "drafter"): if hasattr(self, "drafter"):
logger.info("Loading drafter model...") logger.info_once("Loading drafter model...")
self.drafter.load_model(self.model) self.drafter.load_model(self.model)
if (
hasattr(self.drafter, "model")
and is_mixture_of_experts(self.drafter.model)
and self.parallel_config.enable_eplb
):
logger.info_once(
"EPLB is enabled for drafter model %s.",
self.vllm_config.speculative_config.draft_model_config.model,
)
global_expert_load = (
global_expert_loads[eplb_models]
if global_expert_loads
else None
)
old_global_expert_indices = (
old_global_expert_indices_per_model[eplb_models]
if old_global_expert_indices_per_model
else None
)
if self.eplb_state is None:
self.eplb_state = EplbState(self.parallel_config, self.device)
self.eplb_state.add_model(
self.drafter.model,
self.vllm_config.speculative_config.draft_model_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
)
eplb_models += 1
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
if not supports_eagle3(self.get_model()): if not supports_eagle3(self.get_model()):
raise RuntimeError( raise RuntimeError(
...@@ -3001,18 +3016,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3001,18 +3016,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scope="local", scope="local",
) )
prepare_communication_buffer_for_model(self.model) prepare_communication_buffer_for_model(self.model)
self.is_multimodal_pruning_enabled = ( self.is_multimodal_pruning_enabled = (
supports_multimodal_pruning(self.get_model()) supports_multimodal_pruning(self.get_model())
and self.model_config.multimodal_config.is_multimodal_pruning_enabled() and self.model_config.multimodal_config.is_multimodal_pruning_enabled()
) )
if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
logger.info("EPLB is enabled for model %s.", self.model_config.model) logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
self.eplb_state = EplbState.build( global_expert_load = (
global_expert_loads[eplb_models] if global_expert_loads else None
)
old_global_expert_indices = (
old_global_expert_indices_per_model[eplb_models]
if old_global_expert_indices_per_model
else None
)
assert self.eplb_state is not None
self.eplb_state.add_model(
self.model, self.model,
self.device, self.model_config,
self.parallel_config,
global_expert_load, global_expert_load,
old_global_expert_indices, old_global_expert_indices,
rank_mapping, rank_mapping,
......
...@@ -32,6 +32,7 @@ from vllm.distributed.parallel_state import ( ...@@ -32,6 +32,7 @@ from vllm.distributed.parallel_state import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -613,7 +614,6 @@ class Worker(WorkerBase): ...@@ -613,7 +614,6 @@ class Worker(WorkerBase):
} }
assert self.model_runner.eplb_state is not None assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange( self.model_runner.eplb_state.rearrange(
self.model_runner.model,
execute_shuffle=True, execute_shuffle=True,
global_expert_load=None, global_expert_load=None,
rank_mapping=rank_mapping, rank_mapping=rank_mapping,
...@@ -626,7 +626,7 @@ class Worker(WorkerBase): ...@@ -626,7 +626,7 @@ class Worker(WorkerBase):
self, self,
old_ep_size: int, old_ep_size: int,
new_ep_size: int, new_ep_size: int,
global_expert_load: torch.Tensor | None, global_expert_loads: list[torch.Tensor] | None,
) -> None: ) -> None:
from vllm.distributed.parallel_state import get_ep_group from vllm.distributed.parallel_state import get_ep_group
...@@ -635,9 +635,8 @@ class Worker(WorkerBase): ...@@ -635,9 +635,8 @@ class Worker(WorkerBase):
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)} rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
assert self.model_runner.eplb_state is not None assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange( self.model_runner.eplb_state.rearrange(
self.model_runner.model,
execute_shuffle=True, execute_shuffle=True,
global_expert_load=global_expert_load, global_expert_loads=global_expert_loads,
rank_mapping=rank_mapping, rank_mapping=rank_mapping,
) )
if get_ep_group().rank == 0: if get_ep_group().rank == 0:
...@@ -684,31 +683,56 @@ class Worker(WorkerBase): ...@@ -684,31 +683,56 @@ class Worker(WorkerBase):
get_ep_group, get_ep_group,
prepare_communication_buffer_for_model, prepare_communication_buffer_for_model,
) )
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEParallelConfig,
)
parallel_config = self.vllm_config.parallel_config parallel_config = self.vllm_config.parallel_config
moe_modules = [
module def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
for module in self.model_runner.model.modules() return [
if ( module
module.__class__.__name__ == "FusedMoE" for module in model.modules()
or module.__class__.__name__ == "SharedFusedMoE" if (
) module.__class__.__name__ == "FusedMoE"
] or module.__class__.__name__ == "SharedFusedMoE"
num_local_experts = moe_modules[0].moe_config.num_local_experts )
assert all( ]
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
), "All MoE modules must have the same number of experts" assert all(
for module in moe_modules: module.moe_config.num_local_experts == num_local_experts
module.moe_config.num_experts = num_local_experts * new_ep_size for module in moe_modules
module.global_num_experts = module.moe_config.num_experts ), "All MoE modules must have the same number of experts"
module.moe_parallel_config = FusedMoEParallelConfig.make( for module in moe_modules:
tp_size_=get_tp_group().world_size, module.moe_config.num_experts = num_local_experts * new_ep_size
dp_size_=get_dp_group().world_size, module.global_num_experts = module.moe_config.num_experts
vllm_parallel_config=parallel_config, module.moe_parallel_config = FusedMoEParallelConfig.make(
) tp_size_=get_tp_group().world_size,
module.moe_config.moe_parallel_config = module.moe_parallel_config dp_size_=get_dp_group().world_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
return moe_modules
model_moe_modules = get_moe_modules(self.model_runner.model)
num_local_experts = model_moe_modules[0].moe_config.num_local_experts
update_moe_modules(model_moe_modules, num_local_experts)
drafter_model = None
if hasattr(self.model_runner, "drafter") and hasattr(
self.model_runner.drafter, "model"
):
drafter_model = self.model_runner.drafter.model
if drafter_model is not None and is_mixture_of_experts(drafter_model):
drafter_moe_modules = get_moe_modules(drafter_model)
# Check if drafter and model have matching configs
assert (
drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
), "Drafter and model configs should be the same"
update_moe_modules(drafter_moe_modules, num_local_experts)
if new_ep_size < old_ep_size: if new_ep_size < old_ep_size:
num_local_physical_experts = num_local_experts num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None assert self.model_runner.eplb_state is not None
...@@ -719,7 +743,7 @@ class Worker(WorkerBase): ...@@ -719,7 +743,7 @@ class Worker(WorkerBase):
new_physical_experts new_physical_experts
- self.model_runner.eplb_state.logical_replica_count.shape[1] - self.model_runner.eplb_state.logical_replica_count.shape[1]
) )
global_expert_load = None global_expert_loads = None
else: else:
num_local_physical_experts = torch.tensor( num_local_physical_experts = torch.tensor(
[num_local_experts], dtype=torch.int32, device="cpu" [num_local_experts], dtype=torch.int32, device="cpu"
...@@ -730,18 +754,20 @@ class Worker(WorkerBase): ...@@ -730,18 +754,20 @@ class Worker(WorkerBase):
num_local_physical_experts = num_local_physical_experts.item() num_local_physical_experts = num_local_physical_experts.item()
new_physical_experts = num_local_physical_experts * new_ep_size new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None assert self.model_runner.eplb_state is not None
global_expert_load = self.model_runner.eplb_state.rearrange( global_expert_loads = self.model_runner.eplb_state.rearrange(
self.model_runner.model, execute_shuffle=False execute_shuffle=False
) )
parallel_config.eplb_config.num_redundant_experts = ( parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_load.shape[1] new_physical_experts - global_expert_loads[0].shape[1]
) )
prepare_communication_buffer_for_model(self.model_runner.model) prepare_communication_buffer_for_model(self.model_runner.model)
if drafter_model is not None:
prepare_communication_buffer_for_model(drafter_model)
self.model_runner.model.update_physical_experts_metadata( self.model_runner.model.update_physical_experts_metadata(
num_physical_experts=new_physical_experts, num_physical_experts=new_physical_experts,
num_local_physical_experts=num_local_physical_experts, num_local_physical_experts=num_local_physical_experts,
) )
return global_expert_load return global_expert_loads
def reinitialize_distributed( def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest self, reconfig_request: ReconfigureDistributedRequest
...@@ -782,11 +808,11 @@ class Worker(WorkerBase): ...@@ -782,11 +808,11 @@ class Worker(WorkerBase):
self.local_rank, self.local_rank,
) )
global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size) global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
if new_ep_size > old_ep_size: if new_ep_size > old_ep_size:
assert global_expert_load is not None assert global_expert_loads is not None
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load) self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
def save_sharded_state( def save_sharded_state(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment