Unverified commit 21793722, authored by Rui Qiao, committed by GitHub
Browse files

Elastic Expert Parallel Initial Support (#20775)


Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
parent 5782581a
......@@ -6,6 +6,7 @@ from typing import Union
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
......@@ -62,3 +63,11 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
# When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch.
return FutureWrapper(refs[0])
def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest) -> None:
    """Apply a distributed reconfiguration request through the workers.

    Runs ``reinitialize_distributed`` via ``_run_workers`` with the given
    request, then shuts this executor down when the request's
    ``new_data_parallel_rank`` is ``SHUTDOWN_CURRENT_RANK``.

    Args:
        reconfig_request: the reconfiguration request to forward.
    """
    self._run_workers("reinitialize_distributed", reconfig_request)

    shutdown_requested = (reconfig_request.new_data_parallel_rank ==
                          ReconfigureRankType.SHUTDOWN_CURRENT_RANK)
    if not shutdown_requested:
        return
    self.shutdown()
......@@ -49,7 +49,7 @@ class CPUModelRunner(GPUModelRunner):
if k.endswith("_cpu") and isinstance(v, torch.Tensor):
replace_tensor(self.input_batch.block_table, k, k[:-4])
def load_model(self) -> None:
def load_model(self, eep_scale_up: bool = False) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config)
......
......@@ -1745,8 +1745,40 @@ class GPUModelRunner(LoRAModelRunnerMixin):
new_config = update_config(config, config_overrides)
setattr(self, config_name, new_config)
def load_model(self) -> None:
def load_model(self, eep_scale_up: bool = False) -> None:
"""
Args:
eep_scale_up: the model loading is for elastic EP scale up.
"""
logger.info("Starting to load model %s...", self.model_config.model)
if eep_scale_up:
from vllm.distributed.parallel_state import get_ep_group
num_local_physical_experts = torch.empty(1,
dtype=torch.int32,
device="cpu")
torch.distributed.broadcast(num_local_physical_experts,
group=get_ep_group().cpu_group,
group_src=0)
num_local_physical_experts = int(num_local_physical_experts.item())
new_ep_size = get_ep_group().world_size
global_expert_load, old_global_expert_indices = (
EplbState.recv_state())
num_logical_experts = global_expert_load.shape[1]
self.parallel_config.num_redundant_experts = (
num_local_physical_experts * new_ep_size - num_logical_experts)
assert old_global_expert_indices.shape[
1] % num_local_physical_experts == 0
old_ep_size = old_global_expert_indices.shape[
1] // num_local_physical_experts
rank_mapping = {
old_ep_rank: old_ep_rank
for old_ep_rank in range(old_ep_size)
}
else:
global_expert_load = None
old_global_expert_indices = None
rank_mapping = None
with DeviceMemoryProfiler() as m: # noqa: SIM117
time_before_load = time.perf_counter()
model_loader = get_model_loader(self.load_config)
......@@ -1788,6 +1820,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.model,
self.device,
self.parallel_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
)
def save_tensorized_model(
......
......@@ -26,6 +26,7 @@ from vllm.platforms import current_platform
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
......@@ -191,8 +192,9 @@ class Worker(WorkerBase):
else:
from contextlib import nullcontext
context = nullcontext()
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
with context:
self.model_runner.load_model()
self.model_runner.load_model(eep_scale_up=eep_scale_up)
def update_config(self, overrides: dict[str, Any]) -> None:
    """Hand the given config overrides to the model runner.

    Args:
        overrides: mapping passed unchanged to
            ``self.model_runner.update_config``.
    """
    runner = self.model_runner
    runner.update_config(overrides)
......@@ -384,6 +386,161 @@ class Worker(WorkerBase):
# worker will always be healthy as long as it's running.
return
def _eplb_before_scale_down(self, old_ep_size: int,
                            new_ep_size: int) -> None:
    """Rearrange experts through the EPLB state ahead of an EP scale-down.

    A mapping over the old EP ranks is built in which every rank below
    ``new_ep_size`` maps to itself and every other rank maps to ``-1``.
    The EPLB state is then asked to rearrange (with shuffling) using that
    mapping, after which CUDA work is synchronized.

    Args:
        old_ep_size: EP world size before the reconfiguration.
        new_ep_size: EP world size after the reconfiguration.
    """
    from vllm.distributed.parallel_state import get_ep_group

    # Progress messages are emitted by EP rank 0 only.
    is_first_rank = get_ep_group().rank == 0
    if is_first_rank:
        logger.info("[Elastic EP] Starting expert resharding "
                    "before scaling down...")

    rank_mapping = {
        rank: (rank if rank < new_ep_size else -1)
        for rank in range(old_ep_size)
    }

    eplb_state = self.model_runner.eplb_state
    assert eplb_state is not None
    eplb_state.rearrange(self.model_runner.model,
                         execute_shuffle=True,
                         global_expert_load=None,
                         rank_mapping=rank_mapping)
    torch.cuda.synchronize()

    if is_first_rank:
        logger.info("[Elastic EP] Expert resharding completed!")
def _eplb_after_scale_up(
        self, old_ep_size: int, new_ep_size: int,
        global_expert_load: Optional[torch.Tensor]) -> None:
    """Rearrange experts through the EPLB state after an EP scale-up.

    Every old EP rank is mapped to itself, and the EPLB state is asked to
    rearrange (with shuffling) using that mapping and the supplied load.

    Args:
        old_ep_size: EP world size before the reconfiguration.
        new_ep_size: EP world size after the reconfiguration. Not read by
            this method.
        global_expert_load: forwarded unchanged to ``rearrange``.
    """
    from vllm.distributed.parallel_state import get_ep_group

    # Progress messages are emitted by EP rank 0 only.
    is_first_rank = get_ep_group().rank == 0
    if is_first_rank:
        logger.info("[Elastic EP] Starting expert resharding "
                    "after scaling up...")

    identity_mapping = {rank: rank for rank in range(old_ep_size)}

    eplb_state = self.model_runner.eplb_state
    assert eplb_state is not None
    eplb_state.rearrange(self.model_runner.model,
                         execute_shuffle=True,
                         global_expert_load=global_expert_load,
                         rank_mapping=identity_mapping)

    if is_first_rank:
        logger.info("[Elastic EP] Expert resharding completed!")
def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest) -> None:
    """Copy the data-parallel settings of ``reconfig_request`` into the
    parallel config.

    The size, master IP and master port are always overwritten. The rank
    and local rank are overwritten only when the request does not carry
    ``ReconfigureRankType.KEEP_CURRENT_RANK`` for them.

    Args:
        reconfig_request: source of the new data-parallel settings.
    """
    keep_rank = ReconfigureRankType.KEEP_CURRENT_RANK
    config = self.vllm_config.parallel_config

    config.data_parallel_size = reconfig_request.new_data_parallel_size

    requested_rank = reconfig_request.new_data_parallel_rank
    if requested_rank != keep_rank:
        config.data_parallel_rank = requested_rank

    requested_local_rank = reconfig_request.new_data_parallel_rank_local
    if requested_local_rank != keep_rank:
        config.data_parallel_rank_local = requested_local_rank

    config.data_parallel_master_ip = (
        reconfig_request.new_data_parallel_master_ip)
    config.data_parallel_master_port = (
        reconfig_request.new_data_parallel_master_port)
def _reconfigure_moe(self, old_ep_size: int,
                     new_ep_size: int) -> Optional[torch.Tensor]:
    """Reconfigure the model's FusedMoE modules for a new EP size.

    Every ``FusedMoE`` module gets its global expert count and its MoE
    parallel config rebuilt for ``new_ep_size``. The redundant-expert count
    in the parallel config is then recomputed, communication buffers are
    prepared again, and the model's physical-expert metadata is updated.

    Args:
        old_ep_size: EP world size before the reconfiguration.
        new_ep_size: EP world size after the reconfiguration.

    Returns:
        ``None`` when ``new_ep_size < old_ep_size``; otherwise the value
        returned by ``eplb_state.rearrange(..., execute_shuffle=False)``
        (the global expert load).

    Raises:
        AssertionError: if the model has no ``FusedMoE`` module, if the
            modules disagree on their local expert count, or if the model
            runner has no EPLB state.
    """
    from vllm.distributed.parallel_state import (
        get_dp_group, get_ep_group, prepare_communication_buffer_for_model)
    from vllm.model_executor.layers.fused_moe.layer import (
        FusedMoEParallelConfig)

    parallel_config = self.vllm_config.parallel_config
    moe_modules = [
        module for module in self.model_runner.model.modules()
        if module.__class__.__name__ == "FusedMoE"
    ]
    # Fail with a readable message instead of an IndexError on the
    # `moe_modules[0]` lookup below.
    assert moe_modules, (
        "Cannot reconfigure MoE: the model has no FusedMoE modules")
    num_local_experts = moe_modules[0].moe_config.num_local_experts
    assert all(module.moe_config.num_local_experts == num_local_experts
               for module in moe_modules), (
                   "All MoE modules must have the same number of experts")

    for module in moe_modules:
        module.moe_config.num_experts = num_local_experts * new_ep_size
        module.global_num_experts = module.moe_config.num_experts
        module.moe_parallel_config = FusedMoEParallelConfig.make(
            tp_size_=get_tp_group().world_size,
            dp_size_=get_dp_group().world_size,
            vllm_parallel_config=parallel_config,
        )
        module.moe_config.moe_parallel_config = module.moe_parallel_config

    global_expert_load: Optional[torch.Tensor]
    if new_ep_size < old_ep_size:
        # Scale down: sizes are read from the existing EPLB state.
        num_local_physical_experts = num_local_experts
        assert self.model_runner.eplb_state is not None
        new_physical_experts = \
            self.model_runner.eplb_state.physical_to_logical_map.shape[1]
        parallel_config.num_redundant_experts = (
            new_physical_experts -
            self.model_runner.eplb_state.logical_replica_count.shape[1])
        global_expert_load = None
    else:
        # Broadcast the local expert count from group rank 0 over the EP
        # CPU group.
        # NOTE(review): the receiving side looks like the `eep_scale_up`
        # path of `load_model` on newly launched ranks -- confirm.
        local_experts_tensor = torch.tensor([num_local_experts],
                                            dtype=torch.int32,
                                            device="cpu")
        torch.distributed.broadcast(local_experts_tensor,
                                    group=get_ep_group().cpu_group,
                                    group_src=0)
        # Normalise to a plain int, as the `load_model` scale-up path does.
        num_local_physical_experts = int(local_experts_tensor.item())
        new_physical_experts = num_local_physical_experts * new_ep_size
        assert self.model_runner.eplb_state is not None
        global_expert_load = self.model_runner.eplb_state.rearrange(
            self.model_runner.model, execute_shuffle=False)
        parallel_config.num_redundant_experts = (
            new_physical_experts - global_expert_load.shape[1])

    prepare_communication_buffer_for_model(self.model_runner.model)
    self.model_runner.model.update_physical_experts_metadata(
        num_physical_experts=new_physical_experts,
        num_local_physical_experts=num_local_physical_experts)
    return global_expert_load
def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest) -> None:
    """Tear down and rebuild this worker's distributed environment.

    Steps, in order:
      1. If the new EP size is smaller than the current one, rearrange
         experts via ``_eplb_before_scale_down``.
      2. Call ``cleanup_dist_env_and_memory``.
      3. Return early if the request marks this rank for shutdown.
      4. Apply the request to the parallel config and re-run
         ``init_worker_distributed_environment``.
      5. Reconfigure the MoE modules and, if the new EP size is larger,
         rearrange experts via ``_eplb_after_scale_up``.

    Args:
        reconfig_request: the reconfiguration request to apply.
    """
    from vllm.config import set_current_vllm_config
    from vllm.distributed.parallel_state import (
        cleanup_dist_env_and_memory, get_ep_group)

    # Capture the current EP layout before the groups are destroyed below.
    ep_group = get_ep_group()
    old_ep_size = ep_group.world_size
    old_ep_rank = ep_group.rank

    tp_size = get_tp_group().world_size
    pp_size = get_pp_group().world_size
    new_ep_size = reconfig_request.new_data_parallel_size * tp_size * pp_size

    if new_ep_size < old_ep_size:
        self._eplb_before_scale_down(old_ep_size, new_ep_size)

    cleanup_dist_env_and_memory()

    if (reconfig_request.new_data_parallel_rank ==
            ReconfigureRankType.SHUTDOWN_CURRENT_RANK):
        assert old_ep_rank >= new_ep_size
        # shutdown
        return

    self._reconfigure_parallel_config(reconfig_request)

    with set_current_vllm_config(self.vllm_config):
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)

    global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

    if new_ep_size > old_ep_size:
        assert global_expert_load is not None
        self._eplb_after_scale_up(old_ep_size, new_ep_size,
                                  global_expert_load)
def save_sharded_state(
self,
path: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment