Unverified Commit dea26833 authored by Itay Alroy's avatar Itay Alroy Committed by GitHub
Browse files

[1/N] Elastic EP Milestone 2 (#34861)


Signed-off-by: default avatarYongji Wu <wuyongji317@gmail.com>
Signed-off-by: default avatarItay Alroy <ialroy@nvidia.com>
Signed-off-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: default avatarRon Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: default avatarYongji Wu <wuyongji317@gmail.com>
Co-authored-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: default avatarRon Tourgeman <rtourgeman@nvidia.com>
parent 90805ff4
...@@ -29,6 +29,15 @@ PauseMode = Literal["abort", "wait", "keep"] ...@@ -29,6 +29,15 @@ PauseMode = Literal["abort", "wait", "keep"]
# so form part of the external API. # so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error") FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
EEP_NOTIFICATION_CALL_ID = -1
class EEPNotificationType(enum.Enum):
NEW_CORE_ENGINES_INIT_READY = "NEW_CORE_ENGINES_INIT_READY"
NEW_CORE_ENGINES_WEIGHTS_INIT_READY = "NEW_CORE_ENGINES_WEIGHTS_INIT_READY"
RECONFIGURE_FINISHED = "RECONFIGURE_FINISHED"
SHUTDOWN_COMPLETE = "SHUTDOWN_COMPLETE"
class FinishReason(enum.IntEnum): class FinishReason(enum.IntEnum):
""" """
...@@ -235,6 +244,11 @@ class ReconfigureDistributedRequest(msgspec.Struct): ...@@ -235,6 +244,11 @@ class ReconfigureDistributedRequest(msgspec.Struct):
new_data_parallel_rank_local: int new_data_parallel_rank_local: int
new_data_parallel_master_ip: str new_data_parallel_master_ip: str
new_data_parallel_master_port: int new_data_parallel_master_port: int
new_data_parallel_master_port_list: list[int]
new_stateless_world_group_port_list: list[list[int]]
new_stateless_dp_group_port_list: list[list[int]]
new_stateless_ep_group_port_list: list[list[int]]
new_stateless_eplb_group_port_list: list[list[int]]
class ReconfigureRankType(enum.IntEnum): class ReconfigureRankType(enum.IntEnum):
......
...@@ -20,6 +20,7 @@ from vllm.distributed.weight_transfer.base import ( ...@@ -20,6 +20,7 @@ from vllm.distributed.weight_transfer.base import (
) )
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient, StreamingInput from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -647,7 +648,11 @@ class AsyncLLM(EngineClient): ...@@ -647,7 +648,11 @@ class AsyncLLM(EngineClient):
engine_core = self.engine_core engine_core = self.engine_core
output_processor = self.output_processor output_processor = self.output_processor
log_stats = self.log_stats log_stats = self.log_stats
logger_manager = self.logger_manager # We use a mutable list for logger_manager so that it can be updated
# during elastic EP scaling (see scale_elastic_ep) without creating
# a circular reference via self.
self._logger_ref = [self.logger_manager]
logger_ref = self._logger_ref
renderer = self.renderer renderer = self.renderer
chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
...@@ -691,8 +696,8 @@ class AsyncLLM(EngineClient): ...@@ -691,8 +696,8 @@ class AsyncLLM(EngineClient):
# 4) Logging. # 4) Logging.
# TODO(rob): make into a coroutine and launch it in # TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial. # background thread once Prometheus overhead is non-trivial.
if logger_manager: if logger_ref[0]:
logger_manager.record( logger_ref[0].record(
engine_idx=outputs.engine_index, engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
...@@ -976,17 +981,13 @@ class AsyncLLM(EngineClient): ...@@ -976,17 +981,13 @@ class AsyncLLM(EngineClient):
new_data_parallel_size, new_data_parallel_size,
) )
return return
logger.info(
"Waiting for requests to drain before scaling up to %s engines...", if envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS:
new_data_parallel_size, logger.info(
) "VLLM_ELASTIC_EP_DRAIN_REQUESTS is set, "
await self.wait_for_requests_to_drain(drain_timeout) "waiting for requests to drain before scaling"
logger.info( )
"Requests have been drained, proceeding with scale to %s engines", await self.wait_for_requests_to_drain(drain_timeout)
new_data_parallel_size,
)
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# recreate stat loggers # recreate stat loggers
if new_data_parallel_size > old_data_parallel_size and self.log_stats: if new_data_parallel_size > old_data_parallel_size and self.log_stats:
...@@ -999,6 +1000,18 @@ class AsyncLLM(EngineClient): ...@@ -999,6 +1000,18 @@ class AsyncLLM(EngineClient):
engine_idxs=list(range(new_data_parallel_size)), engine_idxs=list(range(new_data_parallel_size)),
custom_stat_loggers=None, custom_stat_loggers=None,
) )
# Update the mutable ref so output_handler picks up the
# new logger without creating a circular reference via self.
if hasattr(self, "_logger_ref"):
self._logger_ref[0] = self.logger_manager
self.logger_manager.log_engine_initialized()
set_scaling_elastic_ep(True)
try:
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
finally:
set_scaling_elastic_ep(False)
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
......
...@@ -71,6 +71,9 @@ class DPCoordinator: ...@@ -71,6 +71,9 @@ class DPCoordinator:
) )
local_only_eng = dp_size == parallel_config.data_parallel_size_local local_only_eng = dp_size == parallel_config.data_parallel_size_local
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
local_only_eng = False
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host) back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
...@@ -201,6 +204,7 @@ class DPCoordinatorProc: ...@@ -201,6 +204,7 @@ class DPCoordinatorProc:
poller = zmq.Poller() poller = zmq.Poller()
poller.register(publish_front, zmq.POLLIN) poller.register(publish_front, zmq.POLLIN)
poller.register(publish_back, zmq.POLLIN)
poller.register(output_back, zmq.POLLIN) poller.register(output_back, zmq.POLLIN)
last_publish_time = 0 last_publish_time = 0
while True: while True:
...@@ -231,6 +235,22 @@ class DPCoordinatorProc: ...@@ -231,6 +235,22 @@ class DPCoordinatorProc:
events = dict(events) events = dict(events)
wave_state_changed = False wave_state_changed = False
if publish_back in events:
buffer = publish_back.recv()
if buffer == b"\x01":
# NOTE(yongji): newly started engine subscribed
# We need to send READY message here instead of receiving
# SCALE_ELASTIC_EP notification from engine core client
# as SCALE_ELASTIC_EP is only sent when
# new engines finished initialization.
# Subscription message, on the other hand, is sent
# by each engine during initialization
publish_back.send(b"READY")
else:
logger.error(
"DP Coordinator receives unexpected message from engines"
)
if publish_front in events: if publish_front in events:
buffer = publish_front.recv() buffer = publish_front.recv()
if buffer in (b"\x01", b"\x00"): if buffer in (b"\x01", b"\x00"):
...@@ -259,7 +279,6 @@ class DPCoordinatorProc: ...@@ -259,7 +279,6 @@ class DPCoordinatorProc:
# current_wave # current_wave
# we note that 0 is the wave number for the new # we note that 0 is the wave number for the new
# engine # engine
engines_running = False
logger.info( logger.info(
"DPCoordinator scaled up from %s to %s engines", "DPCoordinator scaled up from %s to %s engines",
current_count, current_count,
......
...@@ -17,6 +17,7 @@ from typing import Any, TypeVar, cast ...@@ -17,6 +17,7 @@ from typing import Any, TypeVar, cast
import msgspec import msgspec
import zmq import zmq
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.envs import enable_envs_cache from vllm.envs import enable_envs_cache
...@@ -44,6 +45,8 @@ from vllm.v1.core.kv_cache_utils import ( ...@@ -44,6 +45,8 @@ from vllm.v1.core.kv_cache_utils import (
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ( from vllm.v1.engine import (
EEP_NOTIFICATION_CALL_ID,
EEPNotificationType,
EngineCoreOutput, EngineCoreOutput,
EngineCoreOutputs, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequest,
...@@ -110,6 +113,9 @@ class EngineCore: ...@@ -110,6 +113,9 @@ class EngineCore:
self.available_gpu_memory_for_kv_cache = -1 self.available_gpu_memory_for_kv_cache = -1
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self._eep_scale_up_before_kv_init()
# Setup KV Caches and update CacheConfig after profiling. # Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
vllm_config vllm_config
...@@ -233,12 +239,10 @@ class EngineCore: ...@@ -233,12 +239,10 @@ class EngineCore:
has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs) has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
if has_kv_cache: if has_kv_cache:
if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1": if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
dp_group = getattr(self, "dp_group", None) # NOTE(yongji): should already be set
assert dp_group is not None # during _eep_scale_up_before_kv_init
self.available_gpu_memory_for_kv_cache = ( assert self.available_gpu_memory_for_kv_cache > 0
ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
)
available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len( available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
kv_cache_specs kv_cache_specs
) )
...@@ -752,11 +756,22 @@ class EngineCore: ...@@ -752,11 +756,22 @@ class EngineCore:
self.structured_output_manager.grammar_init(req) self.structured_output_manager.grammar_init(req)
return req, request.current_wave return req, request.current_wave
def _eep_scale_up_before_kv_init(self):
raise NotImplementedError
def _eep_send_engine_core_notification(
self,
notification_type: EEPNotificationType,
vllm_config: VllmConfig | None = None,
):
raise NotImplementedError
class EngineCoreProc(EngineCore): class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process.""" """ZMQ-wrapper for running EngineCore in background process."""
ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD" ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
addresses: EngineZmqAddresses
@instrument(span_name="EngineCoreProc init") @instrument(span_name="EngineCoreProc init")
def __init__( def __init__(
...@@ -807,6 +822,13 @@ class EngineCoreProc(EngineCore): ...@@ -807,6 +822,13 @@ class EngineCoreProc(EngineCore):
# and "hybrid" LB modes. # and "hybrid" LB modes.
self.publish_dp_lb_stats = internal_dp_balancing self.publish_dp_lb_stats = internal_dp_balancing
self.addresses = addresses
self.process_input_queue_block = True
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_INIT_READY,
vllm_config=vllm_config,
)
self._init_data_parallel(vllm_config) self._init_data_parallel(vllm_config)
super().__init__( super().__init__(
...@@ -1119,8 +1141,14 @@ class EngineCoreProc(EngineCore): ...@@ -1119,8 +1141,14 @@ class EngineCoreProc(EngineCore):
if logger.isEnabledFor(DEBUG): if logger.isEnabledFor(DEBUG):
logger.debug("EngineCore waiting for work.") logger.debug("EngineCore waiting for work.")
waited = True waited = True
req = self.input_queue.get() block = self.process_input_queue_block
self._handle_client_request(*req) try:
req = self.input_queue.get(block=block)
self._handle_client_request(*req)
except queue.Empty:
break
if not block:
break
if waited: if waited:
logger.debug("EngineCore loop active.") logger.debug("EngineCore loop active.")
...@@ -1290,6 +1318,11 @@ class EngineCoreProc(EngineCore): ...@@ -1290,6 +1318,11 @@ class EngineCoreProc(EngineCore):
for input_socket, _ in poller.poll(): for input_socket, _ in poller.poll():
# (RequestType, RequestData) # (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(copy=False) type_frame, *data_frames = input_socket.recv_multipart(copy=False)
# NOTE(yongji): ignore READY message sent by DP coordinator
# that is used to notify newly started engines
if type_frame.buffer == b"READY":
assert input_socket == coord_socket
continue
request_type = EngineCoreRequestType(bytes(type_frame.buffer)) request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data. # Deserialize the request data.
...@@ -1488,6 +1521,10 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1488,6 +1521,10 @@ class DPEngineCoreProc(EngineCoreProc):
self.current_wave = 0 self.current_wave = 0
self.last_counts = (0, 0) self.last_counts = (0, 0)
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.eep_scaling_state: ElasticEPScalingState | None = None
# Initialize the engine. # Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__( super().__init__(
...@@ -1511,7 +1548,9 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1511,7 +1548,9 @@ class DPEngineCoreProc(EngineCoreProc):
assert 0 <= local_dp_rank <= dp_rank < dp_size assert 0 <= local_dp_rank <= dp_rank < dp_size
self.dp_rank = dp_rank self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() self.dp_group, self.dp_store = (
vllm_config.parallel_config.stateless_init_dp_group(return_store=True)
)
def shutdown(self): def shutdown(self):
super().shutdown() super().shutdown()
...@@ -1574,7 +1613,12 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1574,7 +1613,12 @@ class DPEngineCoreProc(EngineCoreProc):
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
self._process_input_queue() self._process_input_queue()
# 2) Step the engine core. if self.eep_scaling_state is not None:
_ = self.eep_scaling_state.progress()
if self.eep_scaling_state.is_complete():
self.process_input_queue_block = True
self.eep_scaling_state = None
executed = self._process_engine_step() executed = self._process_engine_step()
self._maybe_publish_request_counts() self._maybe_publish_request_counts()
...@@ -1624,54 +1668,129 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1624,54 +1668,129 @@ class DPEngineCoreProc(EngineCoreProc):
def reinitialize_distributed( def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest self, reconfig_request: ReconfigureDistributedRequest
) -> None: ) -> None:
stateless_destroy_torch_distributed_process_group(self.dp_group) from copy import deepcopy
self.shutdown()
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
parallel_config = self.vllm_config.parallel_config
old_dp_size = parallel_config.data_parallel_size new_parallel_config = deepcopy(self.vllm_config.parallel_config)
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size old_dp_size = new_parallel_config.data_parallel_size
if reconfig_request.new_data_parallel_rank != -1: new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank if (
# local rank specifies device visibility, it should not be changed reconfig_request.new_data_parallel_rank
assert ( != ReconfigureRankType.KEEP_CURRENT_RANK
reconfig_request.new_data_parallel_rank_local ):
== ReconfigureRankType.KEEP_CURRENT_RANK new_parallel_config.data_parallel_rank = (
) reconfig_request.new_data_parallel_rank
parallel_config.data_parallel_master_ip = ( )
new_parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip reconfig_request.new_data_parallel_master_ip
) )
parallel_config.data_parallel_master_port = ( new_parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port reconfig_request.new_data_parallel_master_port
) )
if reconfig_request.new_data_parallel_rank != -2: new_parallel_config._data_parallel_master_port_list = (
self.dp_rank = parallel_config.data_parallel_rank reconfig_request.new_data_parallel_master_port_list
self.dp_group = parallel_config.stateless_init_dp_group()
reconfig_request.new_data_parallel_master_port = (
parallel_config.data_parallel_master_port
) )
self.model_executor.reinitialize_distributed(reconfig_request) is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
if reconfig_request.new_data_parallel_size > old_dp_size: is_shutdown = (
assert self.available_gpu_memory_for_kv_cache > 0
# pass available_gpu_memory_for_kv_cache from existing
# engine-cores to new engine-cores so they can directly
# use it in _initialize_kv_caches() rather than profiling.
ParallelConfig.sync_kv_cache_memory_size(
self.dp_group, self.available_gpu_memory_for_kv_cache
)
# NOTE(yongji): newly joined workers require dummy_run even
# CUDA graph is not used
self.model_executor.collective_rpc("compile_or_warm_up_model")
if (
reconfig_request.new_data_parallel_rank reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
): )
self.shutdown()
logger.info("DPEngineCoreProc %s shutdown", self.dp_rank) self.eep_scaling_state = ElasticEPScalingState(
model_executor=self.model_executor,
engine_core=self,
vllm_config=self.vllm_config,
new_parallel_config=new_parallel_config,
worker_type="removing" if is_shutdown else "existing",
scale_type="scale_down" if is_scale_down else "scale_up",
reconfig_request=reconfig_request,
)
self.process_input_queue_block = False
logger.info(
"[Elastic EP] Received reconfiguration request and starting scaling up/down"
)
def _eep_send_engine_core_notification(
self,
notification_type: EEPNotificationType,
vllm_config: VllmConfig | None = None,
):
"""
Send notifications to EngineCoreClient, which can then forward
the notifications to other engine core processes. It is used for:
1) In scale up: new core engines to notify exisiting core engines
that they are ready;
2) In scale down: removing core engines to notify EngineCoreClient
so EngineCoreClient can release their ray placement groups;
3) Both scale up/down: to notify EngineCoreClient that exisiting
core engines have already switched to the new parallel setup.
"""
if vllm_config is None:
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
else: else:
logger.info( dp_rank = vllm_config.parallel_config.data_parallel_rank
"Distributed environment reinitialized for DP rank %s", self.dp_rank notification_data = (notification_type.value, dp_rank)
outputs = EngineCoreOutputs(
utility_output=UtilityOutput(
call_id=EEP_NOTIFICATION_CALL_ID,
result=UtilityResult(notification_data),
) )
)
outputs.engine_index = self.engine_index
if hasattr(self, "output_thread") and self.output_thread.is_alive():
self.output_queue.put_nowait((0, outputs))
else:
encoder = MsgpackEncoder()
with (
zmq.Context() as ctx,
make_zmq_socket(
ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000
) as socket,
):
socket.send_multipart(encoder.encode(outputs))
def eep_handle_engine_core_notification(
self, notification_type: str | EEPNotificationType
):
"""
Handle notification received from EngineCoreClient
(forwarded from new core engines).
"""
assert self.eep_scaling_state is not None
if isinstance(notification_type, str):
notification_type = EEPNotificationType(notification_type)
self.eep_scaling_state.handle_notification(notification_type)
def _eep_scale_up_before_kv_init(self):
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.eep_scaling_state = ElasticEPScalingState(
model_executor=self.model_executor,
engine_core=self,
vllm_config=self.vllm_config,
new_parallel_config=self.vllm_config.parallel_config,
worker_type="new",
scale_type="scale_up",
reconfig_request=None,
)
self.model_executor.collective_rpc("init_device")
self.model_executor.collective_rpc("load_model")
self._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("receive_weights",)
)
self.available_gpu_memory_for_kv_cache = (
ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1)
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("prepare_new_worker",)
)
self.process_input_queue_block = False
class EngineCoreActorMixin: class EngineCoreActorMixin:
......
...@@ -28,11 +28,12 @@ from vllm.tracing import instrument ...@@ -28,11 +28,12 @@ from vllm.tracing import instrument
from vllm.utils.async_utils import in_loop from vllm.utils.async_utils import in_loop
from vllm.utils.network_utils import ( from vllm.utils.network_utils import (
close_sockets, close_sockets,
get_open_port,
get_open_zmq_inproc_path, get_open_zmq_inproc_path,
make_zmq_socket, make_zmq_socket,
) )
from vllm.v1.engine import ( from vllm.v1.engine import (
EEP_NOTIFICATION_CALL_ID,
EEPNotificationType,
EngineCoreOutputs, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestType,
...@@ -47,6 +48,7 @@ from vllm.v1.engine.exceptions import EngineDeadError ...@@ -47,6 +48,7 @@ from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.engine.utils import ( from vllm.v1.engine.utils import (
CoreEngineActorManager, CoreEngineActorManager,
CoreEngineProcManager, CoreEngineProcManager,
get_engine_zmq_addresses,
launch_core_engines, launch_core_engines,
) )
from vllm.v1.executor import Executor from vllm.v1.executor import Executor
...@@ -445,6 +447,63 @@ class BackgroundResources: ...@@ -445,6 +447,63 @@ class BackgroundResources:
raise EngineDeadError() raise EngineDeadError()
@dataclass
class ElasticScalingCache:
existing_core_engines: list[EngineIdentity]
num_new_core_engines: int
pending_notifications: dict[EEPNotificationType, set[int]]
def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int):
"""
Allocate stateless group ports for elastic EP.
"""
from vllm.utils.network_utils import get_open_ports_list
assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled"
world_size = parallel_config.world_size
new_world_size_across_dp = world_size * new_data_parallel_size
num_world_groups = 1
num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size)
num_ep_groups = max(
1,
new_world_size_across_dp
// (new_data_parallel_size * parallel_config.tensor_parallel_size),
)
num_eplb_groups = num_ep_groups
total_ports_needed = (
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
) * 3 + 5
all_ports = get_open_ports_list(total_ports_needed)
new_data_parallel_master_port_list = all_ports[-5:]
all_ports = all_ports[:-5]
new_stateless_world_group_port_list = [
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
]
start_idx = num_world_groups * 3
new_stateless_dp_group_port_list = [
all_ports[i : i + 3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
]
start_idx += num_dp_groups * 3
new_stateless_ep_group_port_list = [
all_ports[i : i + 3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
]
start_idx += num_ep_groups * 3
new_stateless_eplb_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
]
parallel_config._stateless_world_group_port_list = (
new_stateless_world_group_port_list
)
parallel_config._stateless_dp_group_port_list = new_stateless_dp_group_port_list
parallel_config._stateless_ep_group_port_list = new_stateless_ep_group_port_list
parallel_config._stateless_eplb_group_port_list = new_stateless_eplb_group_port_list
parallel_config.data_parallel_master_port = new_data_parallel_master_port_list.pop()
parallel_config._data_parallel_master_port_list = new_data_parallel_master_port_list
class MPClient(EngineCoreClient): class MPClient(EngineCoreClient):
""" """
MPClient: base client for multi-proc EngineCore. MPClient: base client for multi-proc EngineCore.
...@@ -491,32 +550,37 @@ class MPClient(EngineCoreClient): ...@@ -491,32 +550,37 @@ class MPClient(EngineCoreClient):
input_address = client_addresses["input_address"] input_address = client_addresses["input_address"]
output_address = client_addresses["output_address"] output_address = client_addresses["output_address"]
self.stats_update_address = client_addresses.get("stats_update_address") self.stats_update_address = client_addresses.get("stats_update_address")
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, input_address, zmq.ROUTER, bind=True
)
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.PULL
)
else: else:
# Engines are managed by this client. # Engines are managed by this client.
with launch_core_engines(vllm_config, executor_class, log_stats) as ( addresses = get_engine_zmq_addresses(vllm_config)
engine_manager, self.input_socket = self.resources.input_socket = make_zmq_socket(
coordinator, self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True
)
self.resources.output_socket = make_zmq_socket(
self.ctx, addresses.outputs[0], zmq.PULL
)
with launch_core_engines(
vllm_config,
executor_class,
log_stats,
addresses, addresses,
): ) as (engine_manager, coordinator, addresses):
self.resources.coordinator = coordinator self.resources.coordinator = coordinator
self.resources.engine_manager = engine_manager self.resources.engine_manager = engine_manager
(input_address,) = addresses.inputs
(output_address,) = addresses.outputs
self.stats_update_address = addresses.frontend_stats_publish_address self.stats_update_address = addresses.frontend_stats_publish_address
if coordinator is not None: if coordinator is not None:
assert self.stats_update_address == ( assert self.stats_update_address == (
coordinator.get_stats_publish_address() coordinator.get_stats_publish_address()
) )
# Create input and output sockets.
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, input_address, zmq.ROUTER, bind=True
)
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.PULL
)
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_index dp_rank = parallel_config.data_parallel_index
...@@ -877,6 +941,10 @@ class AsyncMPClient(MPClient): ...@@ -877,6 +941,10 @@ class AsyncMPClient(MPClient):
output_socket = resources.output_socket output_socket = resources.output_socket
assert output_socket is not None assert output_socket is not None
notification_callback_handler: (
Callable[[AsyncMPClient, Sequence[Any]], Any] | None
) = getattr(self.__class__, "eep_process_engine_core_notification", None)
async def process_outputs_socket(): async def process_outputs_socket():
try: try:
while True: while True:
...@@ -884,7 +952,26 @@ class AsyncMPClient(MPClient): ...@@ -884,7 +952,26 @@ class AsyncMPClient(MPClient):
resources.validate_alive(frames) resources.validate_alive(frames)
outputs: EngineCoreOutputs = decoder.decode(frames) outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output: if outputs.utility_output:
_process_utility_output(outputs.utility_output, utility_results) if (
outputs.utility_output.call_id == EEP_NOTIFICATION_CALL_ID
and notification_callback_handler is not None
):
assert _self_ref is not None
_self = _self_ref()
if not _self:
return
if outputs.utility_output.result is None:
continue
notification_data = outputs.utility_output.result.result
assert isinstance(notification_data, Sequence)
assert len(notification_data) == 2
asyncio.create_task(
notification_callback_handler(_self, notification_data)
)
else:
_process_utility_output(
outputs.utility_output, utility_results
)
continue continue
if output_handler is not None: if output_handler is not None:
...@@ -1081,6 +1168,8 @@ class DPAsyncMPClient(AsyncMPClient): ...@@ -1081,6 +1168,8 @@ class DPAsyncMPClient(AsyncMPClient):
# Used only by DPLBAsyncMPClient subclass. # Used only by DPLBAsyncMPClient subclass.
self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines] self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines]
self.eep_scaling_cache: ElasticScalingCache | None = None
self.first_req_sock_addr = get_open_zmq_inproc_path() self.first_req_sock_addr = get_open_zmq_inproc_path()
self.first_req_send_socket = self.resources.first_req_send_socket = ( self.first_req_send_socket = self.resources.first_req_send_socket = (
make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True) make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True)
...@@ -1101,12 +1190,6 @@ class DPAsyncMPClient(AsyncMPClient): ...@@ -1101,12 +1190,6 @@ class DPAsyncMPClient(AsyncMPClient):
assert self.stats_update_address is not None assert self.stats_update_address is not None
stats_addr: str = self.stats_update_address stats_addr: str = self.stats_update_address
assert len(self.engine_ranks_managed) > 0 assert len(self.engine_ranks_managed) > 0
# NOTE: running and waiting counts are all global from
# the Coordinator include all global EngineCores. This
# slice includes just the cores managed by this client.
count_slice = slice(
self.engine_ranks_managed[0], self.engine_ranks_managed[-1] + 1
)
async def run_engine_stats_update_task(): async def run_engine_stats_update_task():
with ( with (
...@@ -1145,6 +1228,29 @@ class DPAsyncMPClient(AsyncMPClient): ...@@ -1145,6 +1228,29 @@ class DPAsyncMPClient(AsyncMPClient):
): ):
# Extract new engine count from the decoded message # Extract new engine count from the decoded message
new_engine_count = decoded[1] new_engine_count = decoded[1]
# Update engine_ranks_managed and count_slice
parallel_config = self.vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
assert dp_rank == 0
assert dp_size == new_engine_count
assert not (
parallel_config.data_parallel_hybrid_lb
or parallel_config.data_parallel_external_lb
)
num_ranks = dp_size
self.engine_ranks_managed = list(
range(dp_rank, dp_rank + num_ranks)
)
if len(self.lb_engines) < new_engine_count:
self.lb_engines = self.lb_engines + [
[0, 0]
for _ in range(
new_engine_count - len(self.lb_engines)
)
]
else:
self.lb_engines = self.lb_engines[:new_engine_count]
# Send scale up notification to coordinator # Send scale up notification to coordinator
scale_msg = msgspec.msgpack.encode( scale_msg = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_engine_count) ("SCALE_ELASTIC_EP", new_engine_count)
...@@ -1178,6 +1284,11 @@ class DPAsyncMPClient(AsyncMPClient): ...@@ -1178,6 +1284,11 @@ class DPAsyncMPClient(AsyncMPClient):
self.current_wave = wave self.current_wave = wave
self.engines_running = running self.engines_running = running
if counts is not None: if counts is not None:
# Running and waiting counts are global from the
# Coordinator including all EngineCores. Slice to get
# just the cores managed by this client.
ranks = self.engine_ranks_managed
count_slice = slice(ranks[0], ranks[-1] + 1)
sliced_counts = counts[count_slice] sliced_counts = counts[count_slice]
self.lb_engines = sliced_counts self.lb_engines = sliced_counts
logger.debug( logger.debug(
...@@ -1287,6 +1398,67 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1287,6 +1398,67 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
for req_id in outputs.finished_requests: for req_id in outputs.finished_requests:
self.reqs_in_flight.pop(req_id, None) self.reqs_in_flight.pop(req_id, None)
@staticmethod
async def eep_process_engine_core_notification(
self: "DPLBAsyncMPClient", notification_data: tuple[str, int]
):
cache = self.eep_scaling_cache
notification_type_str, dp_rank = notification_data
try:
notification_type = EEPNotificationType(notification_type_str)
except ValueError as e:
raise ValueError(
f"Unknown EEP notification type: {notification_type_str}"
) from e
if notification_type == EEPNotificationType.RECONFIGURE_FINISHED:
from vllm.v1.engine import UtilityResult
# NOTE(yongji): process a dummy UtilityOutput to resolve the future
# awaited in _eep_wait_for_setup_switch_complete(), signaling that
# all engine cores have completed reconfiguration.
dummy_output = UtilityOutput(
call_id=EEP_NOTIFICATION_CALL_ID, result=UtilityResult(None)
)
_process_utility_output(dummy_output, self.utility_results)
return
assert cache is not None
if notification_type not in cache.pending_notifications:
cache.pending_notifications[notification_type] = set()
if dp_rank in cache.pending_notifications[notification_type]:
raise ValueError(
f"Duplicate notification {notification_type} from dp_rank {dp_rank}"
)
cache.pending_notifications[notification_type].add(dp_rank)
if len(cache.pending_notifications[notification_type]) >= abs(
cache.num_new_core_engines
):
if notification_type == EEPNotificationType.SHUTDOWN_COMPLETE:
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
assert cache.num_new_core_engines < 0
old_dp_size = len(cache.existing_core_engines)
new_dp_size = old_dp_size + cache.num_new_core_engines
self.resources.engine_manager.scale_down_elastic_ep(
old_dp_size, new_dp_size
)
else:
await asyncio.gather(
*[
self._call_utility_async(
"eep_handle_engine_core_notification",
notification_type,
engine=engine,
)
for engine in cache.existing_core_engines
]
)
cache.pending_notifications[notification_type] = set()
if notification_type in [
EEPNotificationType.SHUTDOWN_COMPLETE,
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY,
]:
self.eep_scaling_cache = None
async def abort_requests_async(self, request_ids: list[str]) -> None: async def abort_requests_async(self, request_ids: list[str]) -> None:
if not request_ids or self.resources.engine_dead: if not request_ids or self.resources.engine_dead:
return return
...@@ -1333,6 +1505,20 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1333,6 +1505,20 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
cur_data_parallel_size, new_data_parallel_size cur_data_parallel_size, new_data_parallel_size
) )
async def _eep_wait_for_setup_switch_complete(self) -> None:
"""
Wait for core engines to switch to the new setup.
In eep_process_engine_core_notification(), a dummy UtilityOutput with
EEP_NOTIFICATION_CALL_ID will be set when RECONFIGURE_FINISHED
notification is received from engine 0. We create a future with
that call_id and wait for it to be resolved.
"""
future = asyncio.get_running_loop().create_future()
self.utility_results[EEP_NOTIFICATION_CALL_ID] = future
self._ensure_output_queue_task()
await future
async def _scale_up_elastic_ep( async def _scale_up_elastic_ep(
self, cur_data_parallel_size: int, new_data_parallel_size: int self, cur_data_parallel_size: int, new_data_parallel_size: int
) -> None: ) -> None:
...@@ -1340,38 +1526,57 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1340,38 +1526,57 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
and reconfiguring existing ones.""" and reconfiguring existing ones."""
cur_data_parallel_size = len(self.core_engines) cur_data_parallel_size = len(self.core_engines)
# Phase 1: Send reconfigure messages to all existing engines and wait self.eep_scaling_cache = ElasticScalingCache(
# for them to be sent existing_core_engines=self.core_engines.copy(),
num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
pending_notifications=dict(),
)
parallel_config = self.vllm_config.parallel_config
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
# Phase 1: Send reconfig messages to existing engines
reconfig_futures = [] reconfig_futures = []
self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
for engine in self.core_engines: for engine in self.core_engines:
reconfig_request = ReconfigureDistributedRequest( reconfig_request = ReconfigureDistributedRequest(
new_data_parallel_size=new_data_parallel_size, new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip, new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port, new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
) )
coro = self._call_utility_async( coro = self._call_utility_async(
"reinitialize_distributed", reconfig_request, engine=engine "reinitialize_distributed", reconfig_request, engine=engine
) )
reconfig_futures.append(asyncio.create_task(coro)) reconfig_futures.append(asyncio.create_task(coro))
logger.info("All reconfigure messages sent, starting engine creation") # Phase 2: Create new engines
# Phase 2: Create new engines now that reconfig messages have been sent
# self.resources.engine_manager is guaranteed to be
# CoreEngineActorManager for RayDPClient
assert isinstance(self.resources.engine_manager, CoreEngineActorManager) assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
self.resources.engine_manager.scale_up_elastic_ep( parallel_config.eplb_config.num_redundant_experts = 0
self.vllm_config, new_data_parallel_size start_new_worker_future = asyncio.to_thread(
self.resources.engine_manager.scale_up_elastic_ep,
self.vllm_config,
new_data_parallel_size,
) )
wait_future = self._eep_wait_for_setup_switch_complete()
# Phase 3: Wait for new engines to be created
# and reconfig messages to be received
await asyncio.gather(start_new_worker_future, *reconfig_futures)
logger.info("[Elastic EP] Successfully started new engines")
# Create new CoreEngine objects for the new engines # Create new CoreEngine objects for the new engines
new_engine_identities = set() new_engine_identities = set()
for i in range(cur_data_parallel_size, new_data_parallel_size): for i in range(cur_data_parallel_size, new_data_parallel_size):
new_engine = i.to_bytes(2, "little") new_engine = i.to_bytes(2, "little")
self.core_engines.append(new_engine) self.core_engines.append(new_engine)
# NOTE(yongji): we don't update lb_engines here,
# we let run_engine_stats_update_task to update it.
new_engine_identities.add(new_engine) new_engine_identities.add(new_engine)
# Wait for ready messages from new engines on the input socket # Wait for ready messages from new engines on the input socket
...@@ -1387,10 +1592,11 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1387,10 +1592,11 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
identity, _ = sync_input_socket.recv_multipart() identity, _ = sync_input_socket.recv_multipart()
new_engine_identities.discard(identity) new_engine_identities.discard(identity)
# Phase 3: Wait for all existing engines to complete reconfiguration # NOTE(yongji): Before we schedule any requests on the new workers,
logger.info("Waiting for existing engines to complete reconfiguration") # we should wait for them to switch to the new setup.
await asyncio.gather(*reconfig_futures) await wait_future
# Update the parallel config
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# Notify coordinator about scale up through existing # Notify coordinator about scale up through existing
# stats_update_task connection # stats_update_task connection
self._ensure_stats_update_task() self._ensure_stats_update_task()
...@@ -1399,8 +1605,6 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1399,8 +1605,6 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
) )
await self.first_req_send_socket.send(scale_up_marker) await self.first_req_send_socket.send(scale_up_marker)
# Update the parallel config
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
logger.info( logger.info(
"[Elastic EP] Scale up completed, new data parallel size: %s", "[Elastic EP] Scale up completed, new data parallel size: %s",
new_data_parallel_size, new_data_parallel_size,
...@@ -1413,7 +1617,14 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1413,7 +1617,14 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
reconfiguring existing engine cores.""" reconfiguring existing engine cores."""
cur_data_parallel_size = len(self.core_engines) cur_data_parallel_size = len(self.core_engines)
self.vllm_config.parallel_config.data_parallel_master_port = get_open_port() self.eep_scaling_cache = ElasticScalingCache(
existing_core_engines=self.core_engines.copy(),
num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
pending_notifications=dict(),
)
parallel_config = self.vllm_config.parallel_config
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
reconfig_futures = [] reconfig_futures = []
for cur_dp_rank, engine in enumerate(self.core_engines): for cur_dp_rank, engine in enumerate(self.core_engines):
...@@ -1421,8 +1632,13 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1421,8 +1632,13 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
new_data_parallel_size=new_data_parallel_size, new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip, new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port, new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
) )
if cur_dp_rank >= new_data_parallel_size: if cur_dp_rank >= new_data_parallel_size:
reconfig_request.new_data_parallel_rank = ( reconfig_request.new_data_parallel_rank = (
...@@ -1433,23 +1649,24 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ...@@ -1433,23 +1649,24 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
) )
reconfig_futures.append(asyncio.create_task(coro)) reconfig_futures.append(asyncio.create_task(coro))
for _ in range(new_data_parallel_size, cur_data_parallel_size): # NOTE(yongji): Immediately stop sending requests to the removing engines.
self.core_engines.pop() self.core_engines = self.core_engines[:new_data_parallel_size]
self.lb_engines = self.lb_engines[:new_data_parallel_size]
wait_future = self._eep_wait_for_setup_switch_complete()
await asyncio.gather(*reconfig_futures) await asyncio.gather(*reconfig_futures)
assert isinstance(self.resources.engine_manager, CoreEngineActorManager) self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
self.resources.engine_manager.scale_down_elastic_ep(
cur_data_parallel_size, new_data_parallel_size
)
self._ensure_stats_update_task() self._ensure_stats_update_task()
scale_down_marker = msgspec.msgpack.encode( scale_down_marker = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_data_parallel_size) ("SCALE_ELASTIC_EP", new_data_parallel_size)
) )
await self.first_req_send_socket.send(scale_down_marker) await self.first_req_send_socket.send(scale_down_marker)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size # NOTE(yongji): Unlike scaling up,
# here we don't actually need to wait for the setup switch to complete.
# We may want to remove it in the future.
await wait_future
logger.info( logger.info(
"[Elastic EP] Scale down completed, new data parallel size: %s", "[Elastic EP] Scale down completed, new data parallel size: %s",
new_data_parallel_size, new_data_parallel_size,
......
...@@ -277,6 +277,8 @@ class CoreEngineActorManager: ...@@ -277,6 +277,8 @@ class CoreEngineActorManager:
else: else:
ray.init() ray.init()
vllm_config.parallel_config.allocate_elastic_ep_ports()
if placement_groups is not None: if placement_groups is not None:
assert local_dp_ranks is not None, ( assert local_dp_ranks is not None, (
"local_dp_ranks must be provided if placement_groups is provided" "local_dp_ranks must be provided if placement_groups is provided"
...@@ -584,6 +586,8 @@ class CoreEngineActorManager: ...@@ -584,6 +586,8 @@ class CoreEngineActorManager:
node_ip = node.node_ip node_ip = node.node_ip
node_id = node.node_id node_id = node.node_id
if device_str not in available_resources[node_id]:
continue
available_gpus = int(available_resources[node_id][device_str]) available_gpus = int(available_resources[node_id][device_str])
# Get total GPUs on this node from the node's resources # Get total GPUs on this node from the node's resources
...@@ -773,26 +777,15 @@ class CoreEngineActorManager: ...@@ -773,26 +777,15 @@ class CoreEngineActorManager:
ray.util.remove_placement_group(pg) ray.util.remove_placement_group(pg)
@contextlib.contextmanager def get_engine_zmq_addresses(
def launch_core_engines(
vllm_config: VllmConfig, vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
num_api_servers: int = 1, num_api_servers: int = 1,
) -> Iterator[ ) -> EngineZmqAddresses:
tuple[ """Allocate ZMQ addresses for engine-client communication."""
CoreEngineProcManager | CoreEngineActorManager | None,
DPCoordinator | None,
EngineZmqAddresses,
]
]:
"""Launch engine and DP coordinator processes as needed."""
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank dp_size = parallel_config.data_parallel_size
host = parallel_config.data_parallel_master_ip host = parallel_config.data_parallel_master_ip
local_engines_only = parallel_config.local_engines_only local_engines_only = parallel_config.local_engines_only
...@@ -806,9 +799,11 @@ def launch_core_engines( ...@@ -806,9 +799,11 @@ def launch_core_engines(
client_local_only = ( client_local_only = (
offline_mode or local_engines_only or (local_engine_count == dp_size) offline_mode or local_engines_only or (local_engine_count == dp_size)
) )
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
client_local_only = False
# Set up input and output addresses. return EngineZmqAddresses(
addresses = EngineZmqAddresses(
inputs=[ inputs=[
get_engine_client_zmq_addr(client_local_only, host) get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers) for _ in range(num_api_servers)
...@@ -819,6 +814,33 @@ def launch_core_engines( ...@@ -819,6 +814,33 @@ def launch_core_engines(
], ],
) )
@contextlib.contextmanager
def launch_core_engines(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
addresses: EngineZmqAddresses,
num_api_servers: int = 1,
) -> Iterator[
tuple[
CoreEngineProcManager | CoreEngineActorManager | None,
DPCoordinator | None,
EngineZmqAddresses,
]
]:
"""Launch engine and DP coordinator processes as needed."""
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank
host = parallel_config.data_parallel_master_ip
local_engines_only = parallel_config.local_engines_only
offline_mode = local_start_index is not None
# Run the DP Coordinator process with rank 0 when in online DP mode. # Run the DP Coordinator process with rank 0 when in online DP mode.
# The coordinator is needed for: # The coordinator is needed for:
# 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing # 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
...@@ -885,6 +907,10 @@ def launch_core_engines( ...@@ -885,6 +907,10 @@ def launch_core_engines(
# will be False. # will be False.
handshake_local_only = offline_mode or local_engine_count == dp_size handshake_local_only = offline_mode or local_engine_count == dp_size
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
handshake_local_only = False
handshake_address = get_engine_client_zmq_addr( handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port handshake_local_only, host, parallel_config.data_parallel_rpc_port
) )
......
...@@ -38,6 +38,7 @@ from vllm.distributed.parallel_state import ( ...@@ -38,6 +38,7 @@ from vllm.distributed.parallel_state import (
get_pcp_group, get_pcp_group,
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
model_parallel_is_initialized,
) )
from vllm.envs import enable_envs_cache from vllm.envs import enable_envs_cache
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -580,17 +581,20 @@ class WorkerProc: ...@@ -580,17 +581,20 @@ class WorkerProc:
) )
self.async_output_copy_thread.start() self.async_output_copy_thread.start()
# Initialize device
self.worker.init_device()
# Set process title and log prefix
self.setup_proc_title_and_log_prefix( self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel enable_ep=vllm_config.parallel_config.enable_expert_parallel
) )
# Load model # Load model
self._init_message_queues(input_shm_handle, vllm_config) self._init_message_queues(input_shm_handle, vllm_config)
self.worker.load_model() is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.worker.init_device()
# Update process title now that parallel groups are initialized
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
self.worker.load_model()
# Enable environment variable cache (e.g. assume no more # Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point) # environment variable overrides after this point)
...@@ -885,6 +889,13 @@ class WorkerProc: ...@@ -885,6 +889,13 @@ class WorkerProc:
@staticmethod @staticmethod
def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: def setup_proc_title_and_log_prefix(enable_ep: bool) -> None:
# Check if parallel groups are initialized first
if not model_parallel_is_initialized():
# Parallel groups not yet initialized, use default process name
set_process_title(name="Worker")
decorate_logs("Worker")
return
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
dp_rank = get_dp_group().rank_in_group dp_rank = get_dp_group().rank_in_group
pp_size = get_pp_group().world_size pp_size = get_pp_group().world_size
......
...@@ -382,8 +382,10 @@ class RayDistributedExecutor(Executor): ...@@ -382,8 +382,10 @@ class RayDistributedExecutor(Executor):
all_kwargs.append(kwargs) all_kwargs.append(kwargs)
self.collective_rpc("init_worker", args=(all_kwargs,)) self.collective_rpc("init_worker", args=(all_kwargs,))
self.collective_rpc("init_device") is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
self.collective_rpc("load_model") if not is_eep_new_worker:
self.collective_rpc("init_device")
self.collective_rpc("load_model")
for pp_rank in range(self.parallel_config.pipeline_parallel_size): for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([]) self.pp_tp_workers.append([])
......
...@@ -14,7 +14,6 @@ import vllm.envs as envs ...@@ -14,7 +14,6 @@ import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.serial_utils import run_method from vllm.v1.serial_utils import run_method
...@@ -43,9 +42,11 @@ class UniProcExecutor(Executor): ...@@ -43,9 +42,11 @@ class UniProcExecutor(Executor):
max_workers=1, thread_name_prefix="WorkerAsyncOutput" max_workers=1, thread_name_prefix="WorkerAsyncOutput"
) )
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
self.driver_worker.init_worker(all_kwargs=[kwargs]) self.driver_worker.init_worker(all_kwargs=[kwargs])
self.driver_worker.init_device() if not is_eep_new_worker:
self.driver_worker.load_model() self.driver_worker.init_device()
self.driver_worker.load_model()
def _distributed_args(self) -> tuple[str, int, int]: def _distributed_args(self) -> tuple[str, int, int]:
"""Return (distributed_init_method, rank, local_rank).""" """Return (distributed_init_method, rank, local_rank)."""
...@@ -122,16 +123,6 @@ class UniProcExecutor(Executor): ...@@ -122,16 +123,6 @@ class UniProcExecutor(Executor):
# it's running. # it's running.
return return
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
self.driver_worker.reinitialize_distributed(reconfig_request)
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
self.shutdown()
def shutdown(self) -> None: def shutdown(self) -> None:
if worker := self.driver_worker: if worker := self.driver_worker:
worker.shutdown() worker.shutdown()
......
...@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner): ...@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner):
v.gpu = v.cpu v.gpu = v.cpu
@instrument(span_name="Loading (CPU)") @instrument(span_name="Loading (CPU)")
def load_model(self, eep_scale_up: bool = False) -> None: def load_model(self, load_dummy_weights: bool = False) -> None:
if load_dummy_weights:
raise ValueError(
"Loading dummy weights (needed for elastic EP scale-up) "
"Is not supported by the CPU Model Runner."
)
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config) self.model = get_model(vllm_config=self.vllm_config)
......
...@@ -461,6 +461,8 @@ class GPUModelRunner( ...@@ -461,6 +461,8 @@ class GPUModelRunner(
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
self.eplb_state: EplbState | None = None self.eplb_state: EplbState | None = None
# NOTE(yongji): flag to temporarily disable EPLB during scaling up/down
self.eep_eplb_suppressed = False
""" """
State of the expert parallelism load balancer. State of the expert parallelism load balancer.
...@@ -2702,7 +2704,7 @@ class GPUModelRunner( ...@@ -2702,7 +2704,7 @@ class GPUModelRunner(
""" """
Step for the EPLB (Expert Parallelism Load Balancing) state. Step for the EPLB (Expert Parallelism Load Balancing) state.
""" """
if not self.parallel_config.enable_eplb: if not self.parallel_config.enable_eplb or self.eep_eplb_suppressed:
return return
assert self.eplb_state is not None assert self.eplb_state is not None
...@@ -2714,6 +2716,23 @@ class GPUModelRunner( ...@@ -2714,6 +2716,23 @@ class GPUModelRunner(
log_stats=self.parallel_config.eplb_config.log_balancedness, log_stats=self.parallel_config.eplb_config.log_balancedness,
) )
def setup_eplb_from_mapping(
self,
expanded_physical_to_logical: torch.Tensor,
old_num_physical_experts: int,
) -> None:
model = self.get_model()
assert is_mixture_of_experts(model)
self.eplb_state = EplbState.from_mapping(
model=model,
model_config=self.model_config,
device=self.device,
parallel_config=self.parallel_config,
expanded_physical_to_logical=expanded_physical_to_logical,
num_valid_physical_experts=old_num_physical_experts,
)
def _pool( def _pool(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -4175,21 +4194,16 @@ class GPUModelRunner( ...@@ -4175,21 +4194,16 @@ class GPUModelRunner(
setattr(self, config_name, new_config) setattr(self, config_name, new_config)
@instrument(span_name="Loading (GPU)") @instrument(span_name="Loading (GPU)")
def load_model(self, eep_scale_up: bool = False) -> None: def load_model(self, load_dummy_weights: bool = False) -> None:
""" """
Args: Args:
eep_scale_up: the model loading is for elastic EP scale up. load_dummy_weights: load dummy weights instead of real weights.
""" """
logger.info_once( logger.info_once(
"Starting to load model %s...", "Starting to load model %s...",
self.model_config.model, self.model_config.model,
scope="global", scope="global",
) )
global_expert_loads, old_global_expert_indices_per_model, rank_mapping = (
EplbState.get_eep_state(self.parallel_config)
if eep_scale_up
else (None, None, None)
)
if self.parallel_config.enable_eplb: if self.parallel_config.enable_eplb:
self.eplb_state = EplbState(self.parallel_config, self.device) self.eplb_state = EplbState(self.parallel_config, self.device)
...@@ -4198,6 +4212,8 @@ class GPUModelRunner( ...@@ -4198,6 +4212,8 @@ class GPUModelRunner(
try: try:
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler() as m:
time_before_load = time.perf_counter() time_before_load = time.perf_counter()
if load_dummy_weights:
self.load_config.load_format = "dummy"
model_loader = get_model_loader(self.load_config) model_loader = get_model_loader(self.load_config)
self.model = model_loader.load_model( self.model = model_loader.load_model(
vllm_config=self.vllm_config, model_config=self.model_config vllm_config=self.vllm_config, model_config=self.model_config
...@@ -4214,6 +4230,9 @@ class GPUModelRunner( ...@@ -4214,6 +4230,9 @@ class GPUModelRunner(
and is_mixture_of_experts(self.drafter.model) and is_mixture_of_experts(self.drafter.model)
and self.parallel_config.enable_eplb and self.parallel_config.enable_eplb
): ):
assert not self.parallel_config.enable_elastic_ep, (
"Elastic EP is not supported with drafter model."
)
spec_config = self.vllm_config.speculative_config spec_config = self.vllm_config.speculative_config
assert spec_config is not None assert spec_config is not None
assert spec_config.draft_model_config is not None assert spec_config.draft_model_config is not None
...@@ -4221,17 +4240,6 @@ class GPUModelRunner( ...@@ -4221,17 +4240,6 @@ class GPUModelRunner(
"EPLB is enabled for drafter model %s.", "EPLB is enabled for drafter model %s.",
spec_config.draft_model_config.model, spec_config.draft_model_config.model,
) )
global_expert_load = (
global_expert_loads[eplb_models]
if global_expert_loads
else None
)
old_global_expert_indices = (
old_global_expert_indices_per_model[eplb_models]
if old_global_expert_indices_per_model
else None
)
if self.eplb_state is None: if self.eplb_state is None:
self.eplb_state = EplbState( self.eplb_state = EplbState(
self.parallel_config, self.device self.parallel_config, self.device
...@@ -4239,9 +4247,6 @@ class GPUModelRunner( ...@@ -4239,9 +4247,6 @@ class GPUModelRunner(
self.eplb_state.add_model( self.eplb_state.add_model(
self.drafter.model, self.drafter.model,
spec_config.draft_model_config, spec_config.draft_model_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
) )
eplb_models += 1 eplb_models += 1
...@@ -4283,11 +4288,12 @@ class GPUModelRunner( ...@@ -4283,11 +4288,12 @@ class GPUModelRunner(
time_after_load - time_before_load, time_after_load - time_before_load,
scope="local", scope="local",
) )
prepare_communication_buffer_for_model(self.model) if not load_dummy_weights:
if (drafter := getattr(self, "drafter", None)) and ( prepare_communication_buffer_for_model(self.model)
drafter_model := getattr(drafter, "model", None) if (drafter := getattr(self, "drafter", None)) and (
): drafter_model := getattr(drafter, "model", None)
prepare_communication_buffer_for_model(drafter_model) ):
prepare_communication_buffer_for_model(drafter_model)
mm_config = self.model_config.multimodal_config mm_config = self.model_config.multimodal_config
self.is_multimodal_pruning_enabled = ( self.is_multimodal_pruning_enabled = (
supports_multimodal_pruning(self.get_model()) supports_multimodal_pruning(self.get_model())
...@@ -4295,26 +4301,19 @@ class GPUModelRunner( ...@@ -4295,26 +4301,19 @@ class GPUModelRunner(
and mm_config.is_multimodal_pruning_enabled() and mm_config.is_multimodal_pruning_enabled()
) )
if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: if (
is_mixture_of_experts(self.model)
and self.parallel_config.enable_eplb
and not load_dummy_weights
):
logger.info_once("EPLB is enabled for model %s.", self.model_config.model) logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
global_expert_load = (
global_expert_loads[eplb_models] if global_expert_loads else None
)
old_global_expert_indices = (
old_global_expert_indices_per_model[eplb_models]
if old_global_expert_indices_per_model
else None
)
assert self.eplb_state is not None assert self.eplb_state is not None
self.eplb_state.add_model( self.eplb_state.add_model(
self.model, self.model,
self.model_config, self.model_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
) )
if self.eplb_state.is_async: if self.eplb_state.is_async:
self.eplb_state.start_async_loop(rank_mapping=rank_mapping) self.eplb_state.start_async_loop()
if ( if (
self.vllm_config.compilation_config.mode self.vllm_config.compilation_config.mode
......
...@@ -7,11 +7,10 @@ import os ...@@ -7,11 +7,10 @@ import os
from collections.abc import Callable from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext from contextlib import AbstractContextManager, nullcontext
from types import NoneType from types import NoneType
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any
import numpy as np import numpy as np
import torch import torch
import torch.distributed
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
...@@ -32,14 +31,12 @@ from vllm.distributed.kv_transfer import ( ...@@ -32,14 +31,12 @@ from vllm.distributed.kv_transfer import (
) )
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
Handle, Handle,
get_pcp_group,
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
) )
from vllm.distributed.weight_transfer import WeightTransferEngineFactory from vllm.distributed.weight_transfer import WeightTransferEngineFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
...@@ -49,7 +46,6 @@ from vllm.tracing import instrument ...@@ -49,7 +46,6 @@ from vllm.tracing import instrument
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ( from vllm.v1.outputs import (
AsyncModelRunnerOutput, AsyncModelRunnerOutput,
...@@ -124,6 +120,10 @@ class Worker(WorkerBase): ...@@ -124,6 +120,10 @@ class Worker(WorkerBase):
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
torch.set_float32_matmul_precision(precision) torch.set_float32_matmul_precision(precision)
from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor
self.elastic_ep_executor = ElasticEPScalingExecutor(self)
# Buffers saved before sleep # Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {} self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
...@@ -317,12 +317,29 @@ class Worker(WorkerBase): ...@@ -317,12 +317,29 @@ class Worker(WorkerBase):
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation. # to hijack tensor allocation.
def load_model(self) -> None: def load_model(self) -> None:
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1" dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
if dummy_weights:
(
expanded_physical_to_logical,
num_logical_experts,
old_num_physical_experts,
) = self.elastic_ep_executor.receive_expert_mapping()
num_physical_experts = expanded_physical_to_logical.shape[1]
self.parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)
with ( with (
self._maybe_get_memory_pool_context(tag="weights"), self._maybe_get_memory_pool_context(tag="weights"),
set_current_vllm_config(self.vllm_config), set_current_vllm_config(self.vllm_config),
): ):
self.model_runner.load_model(eep_scale_up=eep_scale_up) self.model_runner.load_model(load_dummy_weights=dummy_weights)
if dummy_weights:
self.model_runner.setup_eplb_from_mapping(
expanded_physical_to_logical, old_num_physical_experts
)
self.model_runner.eep_eplb_suppressed = True
def update_config(self, overrides: dict[str, Any]) -> None: def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides) self.model_runner.update_config(overrides)
...@@ -801,227 +818,6 @@ class Worker(WorkerBase): ...@@ -801,227 +818,6 @@ class Worker(WorkerBase):
# worker will always be healthy as long as it's running. # worker will always be healthy as long as it's running.
return return
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info(
"[Elastic EP] Starting expert resharding before scaling down..."
)
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_loads=None,
rank_mapping=rank_mapping,
)
torch.cuda.synchronize()
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def _eplb_after_scale_up(
self,
old_ep_size: int,
new_ep_size: int,
global_expert_loads: list[torch.Tensor] | None,
) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Starting expert resharding after scaling up...")
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_loads=global_expert_loads,
rank_mapping=rank_mapping,
)
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def _reconfigure_parallel_config(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
"""
Update parallel config with provided reconfig_request
"""
parallel_config = self.vllm_config.parallel_config
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
def _reconfigure_moe(
self, old_ep_size: int, new_ep_size: int
) -> list[torch.Tensor] | None:
"""
Reconfigure MoE modules with provided reconfig_request
Return the global expert load if new_ep_size > old_ep_size,
otherwise None
"""
from vllm.distributed.parallel_state import (
get_dp_group,
get_ep_group,
prepare_communication_buffer_for_model,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEParallelConfig,
)
parallel_config = self.vllm_config.parallel_config
def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
return [
module
for module in model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
assert all(
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules
), "All MoE modules must have the same number of experts"
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
return moe_modules
model_moe_modules = get_moe_modules(self.model_runner.model)
num_local_experts = model_moe_modules[0].moe_config.num_local_experts
update_moe_modules(model_moe_modules, num_local_experts)
drafter_model = None
if hasattr(self.model_runner, "drafter") and hasattr(
self.model_runner.drafter, "model"
):
drafter_model = self.model_runner.drafter.model
if drafter_model is not None and is_mixture_of_experts(drafter_model):
drafter_moe_modules = get_moe_modules(drafter_model)
# Check if drafter and model have matching configs
assert (
drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
), "Drafter and model configs should be the same"
update_moe_modules(drafter_moe_modules, num_local_experts)
if new_ep_size < old_ep_size:
num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None
new_physical_experts = (
self.model_runner.eplb_state.physical_to_logical_map.shape[1] # type: ignore[attr-defined]
)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts
- self.model_runner.eplb_state.logical_replica_count.shape[1] # type: ignore[attr-defined]
)
global_expert_loads = None
else:
num_local_physical_experts_tensor = torch.tensor(
[num_local_experts], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
num_local_physical_experts_tensor,
group=get_ep_group().cpu_group,
group_src=0,
)
num_local_physical_experts = int(num_local_physical_experts_tensor.item())
new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None
global_expert_loads_any = self.model_runner.eplb_state.rearrange(
execute_shuffle=False
)
global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_loads[0].shape[1]
)
prepare_communication_buffer_for_model(self.model_runner.model)
if drafter_model is not None:
prepare_communication_buffer_for_model(drafter_model)
self.model_runner.model.update_physical_experts_metadata(
num_physical_experts=new_physical_experts,
num_local_physical_experts=num_local_physical_experts,
)
return global_expert_loads
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
cleanup_dist_env_and_memory,
get_ep_group,
)
old_ep_size = get_ep_group().world_size
old_ep_rank = get_ep_group().rank
new_ep_size = (
reconfig_request.new_data_parallel_size
* get_tp_group().world_size
* get_pp_group().world_size
)
if new_ep_size < old_ep_size:
self._eplb_before_scale_down(old_ep_size, new_ep_size)
cleanup_dist_env_and_memory()
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
assert old_ep_rank >= new_ep_size
# shutdown
return
self._reconfigure_parallel_config(reconfig_request)
with set_current_vllm_config(self.vllm_config):
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
)
global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
if new_ep_size > old_ep_size:
assert global_expert_loads is not None
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
def save_sharded_state( def save_sharded_state(
self, self,
path: str, path: str,
...@@ -1118,6 +914,9 @@ class Worker(WorkerBase): ...@@ -1118,6 +914,9 @@ class Worker(WorkerBase):
if weight_transfer_engine := getattr(self, "weight_transfer_engine", None): if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
weight_transfer_engine.shutdown() weight_transfer_engine.shutdown()
def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
return self.elastic_ep_executor.execute(execute_method, *args, **kwargs)
def init_worker_distributed_environment( def init_worker_distributed_environment(
vllm_config: VllmConfig, vllm_config: VllmConfig,
......
...@@ -66,6 +66,23 @@ class WorkspaceManager: ...@@ -66,6 +66,23 @@ class WorkspaceManager:
], ],
) )
def unlock(self) -> None:
"""Unlock the workspace to allow growth.
This is used during elastic EP scaling when the workspace size
needs to grow due to changes in the number of experts.
"""
self._locked = False
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Workspace unlocked. Current sizes: %s",
[
self._workspace_size_bytes(ws) / _MB
for ws in self._current_workspaces
if ws is not None
],
)
def is_locked(self) -> bool: def is_locked(self) -> bool:
"""Check if workspace is locked.""" """Check if workspace is locked."""
return self._locked return self._locked
...@@ -242,6 +259,17 @@ def lock_workspace() -> None: ...@@ -242,6 +259,17 @@ def lock_workspace() -> None:
current_workspace_manager().lock() current_workspace_manager().lock()
def unlock_workspace() -> None:
"""Unlock the workspace to allow growth.
This is used during elastic EP scaling when the workspace size
needs to grow due to changes in the number of experts.
After scaling operations complete, lock_workspace() should be
called again to prevent unexpected allocations.
"""
current_workspace_manager().unlock()
def reset_workspace_manager() -> None: def reset_workspace_manager() -> None:
"""Reset the workspace manager to uninitialized state. """Reset the workspace manager to uninitialized state.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment