Unverified Commit dea26833 authored by Itay Alroy's avatar Itay Alroy Committed by GitHub
Browse files

[1/N] Elastic EP Milestone 2 (#34861)


Signed-off-by: default avatarYongji Wu <wuyongji317@gmail.com>
Signed-off-by: default avatarItay Alroy <ialroy@nvidia.com>
Signed-off-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: default avatarRon Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: default avatarYongji Wu <wuyongji317@gmail.com>
Co-authored-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: default avatarRon Tourgeman <rtourgeman@nvidia.com>
parent 90805ff4
...@@ -29,6 +29,15 @@ PauseMode = Literal["abort", "wait", "keep"] ...@@ -29,6 +29,15 @@ PauseMode = Literal["abort", "wait", "keep"]
# so form part of the external API. # so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error") FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
EEP_NOTIFICATION_CALL_ID = -1
class EEPNotificationType(enum.Enum):
NEW_CORE_ENGINES_INIT_READY = "NEW_CORE_ENGINES_INIT_READY"
NEW_CORE_ENGINES_WEIGHTS_INIT_READY = "NEW_CORE_ENGINES_WEIGHTS_INIT_READY"
RECONFIGURE_FINISHED = "RECONFIGURE_FINISHED"
SHUTDOWN_COMPLETE = "SHUTDOWN_COMPLETE"
class FinishReason(enum.IntEnum): class FinishReason(enum.IntEnum):
""" """
...@@ -235,6 +244,11 @@ class ReconfigureDistributedRequest(msgspec.Struct): ...@@ -235,6 +244,11 @@ class ReconfigureDistributedRequest(msgspec.Struct):
new_data_parallel_rank_local: int new_data_parallel_rank_local: int
new_data_parallel_master_ip: str new_data_parallel_master_ip: str
new_data_parallel_master_port: int new_data_parallel_master_port: int
new_data_parallel_master_port_list: list[int]
new_stateless_world_group_port_list: list[list[int]]
new_stateless_dp_group_port_list: list[list[int]]
new_stateless_ep_group_port_list: list[list[int]]
new_stateless_eplb_group_port_list: list[list[int]]
class ReconfigureRankType(enum.IntEnum): class ReconfigureRankType(enum.IntEnum):
......
...@@ -20,6 +20,7 @@ from vllm.distributed.weight_transfer.base import ( ...@@ -20,6 +20,7 @@ from vllm.distributed.weight_transfer.base import (
) )
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient, StreamingInput from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -647,7 +648,11 @@ class AsyncLLM(EngineClient): ...@@ -647,7 +648,11 @@ class AsyncLLM(EngineClient):
engine_core = self.engine_core engine_core = self.engine_core
output_processor = self.output_processor output_processor = self.output_processor
log_stats = self.log_stats log_stats = self.log_stats
logger_manager = self.logger_manager # We use a mutable list for logger_manager so that it can be updated
# during elastic EP scaling (see scale_elastic_ep) without creating
# a circular reference via self.
self._logger_ref = [self.logger_manager]
logger_ref = self._logger_ref
renderer = self.renderer renderer = self.renderer
chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
...@@ -691,8 +696,8 @@ class AsyncLLM(EngineClient): ...@@ -691,8 +696,8 @@ class AsyncLLM(EngineClient):
# 4) Logging. # 4) Logging.
# TODO(rob): make into a coroutine and launch it in # TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial. # background thread once Prometheus overhead is non-trivial.
if logger_manager: if logger_ref[0]:
logger_manager.record( logger_ref[0].record(
engine_idx=outputs.engine_index, engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
...@@ -976,17 +981,13 @@ class AsyncLLM(EngineClient): ...@@ -976,17 +981,13 @@ class AsyncLLM(EngineClient):
new_data_parallel_size, new_data_parallel_size,
) )
return return
logger.info(
"Waiting for requests to drain before scaling up to %s engines...", if envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS:
new_data_parallel_size, logger.info(
) "VLLM_ELASTIC_EP_DRAIN_REQUESTS is set, "
await self.wait_for_requests_to_drain(drain_timeout) "waiting for requests to drain before scaling"
logger.info( )
"Requests have been drained, proceeding with scale to %s engines", await self.wait_for_requests_to_drain(drain_timeout)
new_data_parallel_size,
)
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# recreate stat loggers # recreate stat loggers
if new_data_parallel_size > old_data_parallel_size and self.log_stats: if new_data_parallel_size > old_data_parallel_size and self.log_stats:
...@@ -999,6 +1000,18 @@ class AsyncLLM(EngineClient): ...@@ -999,6 +1000,18 @@ class AsyncLLM(EngineClient):
engine_idxs=list(range(new_data_parallel_size)), engine_idxs=list(range(new_data_parallel_size)),
custom_stat_loggers=None, custom_stat_loggers=None,
) )
# Update the mutable ref so output_handler picks up the
# new logger without creating a circular reference via self.
if hasattr(self, "_logger_ref"):
self._logger_ref[0] = self.logger_manager
self.logger_manager.log_engine_initialized()
set_scaling_elastic_ep(True)
try:
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
finally:
set_scaling_elastic_ep(False)
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
......
...@@ -71,6 +71,9 @@ class DPCoordinator: ...@@ -71,6 +71,9 @@ class DPCoordinator:
) )
local_only_eng = dp_size == parallel_config.data_parallel_size_local local_only_eng = dp_size == parallel_config.data_parallel_size_local
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
local_only_eng = False
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host) back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
...@@ -201,6 +204,7 @@ class DPCoordinatorProc: ...@@ -201,6 +204,7 @@ class DPCoordinatorProc:
poller = zmq.Poller() poller = zmq.Poller()
poller.register(publish_front, zmq.POLLIN) poller.register(publish_front, zmq.POLLIN)
poller.register(publish_back, zmq.POLLIN)
poller.register(output_back, zmq.POLLIN) poller.register(output_back, zmq.POLLIN)
last_publish_time = 0 last_publish_time = 0
while True: while True:
...@@ -231,6 +235,22 @@ class DPCoordinatorProc: ...@@ -231,6 +235,22 @@ class DPCoordinatorProc:
events = dict(events) events = dict(events)
wave_state_changed = False wave_state_changed = False
if publish_back in events:
buffer = publish_back.recv()
if buffer == b"\x01":
# NOTE(yongji): newly started engine subscribed
# We need to send READY message here instead of receiving
# SCALE_ELASTIC_EP notification from engine core client
# as SCALE_ELASTIC_EP is only sent when
# new engines finished initialization.
# Subscription message, on the other hand, is sent
# by each engine during initialization
publish_back.send(b"READY")
else:
logger.error(
"DP Coordinator receives unexpected message from engines"
)
if publish_front in events: if publish_front in events:
buffer = publish_front.recv() buffer = publish_front.recv()
if buffer in (b"\x01", b"\x00"): if buffer in (b"\x01", b"\x00"):
...@@ -259,7 +279,6 @@ class DPCoordinatorProc: ...@@ -259,7 +279,6 @@ class DPCoordinatorProc:
# current_wave # current_wave
# we note that 0 is the wave number for the new # we note that 0 is the wave number for the new
# engine # engine
engines_running = False
logger.info( logger.info(
"DPCoordinator scaled up from %s to %s engines", "DPCoordinator scaled up from %s to %s engines",
current_count, current_count,
......
...@@ -17,6 +17,7 @@ from typing import Any, TypeVar, cast ...@@ -17,6 +17,7 @@ from typing import Any, TypeVar, cast
import msgspec import msgspec
import zmq import zmq
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.envs import enable_envs_cache from vllm.envs import enable_envs_cache
...@@ -44,6 +45,8 @@ from vllm.v1.core.kv_cache_utils import ( ...@@ -44,6 +45,8 @@ from vllm.v1.core.kv_cache_utils import (
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ( from vllm.v1.engine import (
EEP_NOTIFICATION_CALL_ID,
EEPNotificationType,
EngineCoreOutput, EngineCoreOutput,
EngineCoreOutputs, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequest,
...@@ -110,6 +113,9 @@ class EngineCore: ...@@ -110,6 +113,9 @@ class EngineCore:
self.available_gpu_memory_for_kv_cache = -1 self.available_gpu_memory_for_kv_cache = -1
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self._eep_scale_up_before_kv_init()
# Setup KV Caches and update CacheConfig after profiling. # Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
vllm_config vllm_config
...@@ -233,12 +239,10 @@ class EngineCore: ...@@ -233,12 +239,10 @@ class EngineCore:
has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs) has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
if has_kv_cache: if has_kv_cache:
if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1": if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
dp_group = getattr(self, "dp_group", None) # NOTE(yongji): should already be set
assert dp_group is not None # during _eep_scale_up_before_kv_init
self.available_gpu_memory_for_kv_cache = ( assert self.available_gpu_memory_for_kv_cache > 0
ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
)
available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len( available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
kv_cache_specs kv_cache_specs
) )
...@@ -752,11 +756,22 @@ class EngineCore: ...@@ -752,11 +756,22 @@ class EngineCore:
self.structured_output_manager.grammar_init(req) self.structured_output_manager.grammar_init(req)
return req, request.current_wave return req, request.current_wave
def _eep_scale_up_before_kv_init(self):
raise NotImplementedError
def _eep_send_engine_core_notification(
self,
notification_type: EEPNotificationType,
vllm_config: VllmConfig | None = None,
):
raise NotImplementedError
class EngineCoreProc(EngineCore): class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process.""" """ZMQ-wrapper for running EngineCore in background process."""
ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD" ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
addresses: EngineZmqAddresses
@instrument(span_name="EngineCoreProc init") @instrument(span_name="EngineCoreProc init")
def __init__( def __init__(
...@@ -807,6 +822,13 @@ class EngineCoreProc(EngineCore): ...@@ -807,6 +822,13 @@ class EngineCoreProc(EngineCore):
# and "hybrid" LB modes. # and "hybrid" LB modes.
self.publish_dp_lb_stats = internal_dp_balancing self.publish_dp_lb_stats = internal_dp_balancing
self.addresses = addresses
self.process_input_queue_block = True
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_INIT_READY,
vllm_config=vllm_config,
)
self._init_data_parallel(vllm_config) self._init_data_parallel(vllm_config)
super().__init__( super().__init__(
...@@ -1119,8 +1141,14 @@ class EngineCoreProc(EngineCore): ...@@ -1119,8 +1141,14 @@ class EngineCoreProc(EngineCore):
if logger.isEnabledFor(DEBUG): if logger.isEnabledFor(DEBUG):
logger.debug("EngineCore waiting for work.") logger.debug("EngineCore waiting for work.")
waited = True waited = True
req = self.input_queue.get() block = self.process_input_queue_block
self._handle_client_request(*req) try:
req = self.input_queue.get(block=block)
self._handle_client_request(*req)
except queue.Empty:
break
if not block:
break
if waited: if waited:
logger.debug("EngineCore loop active.") logger.debug("EngineCore loop active.")
...@@ -1290,6 +1318,11 @@ class EngineCoreProc(EngineCore): ...@@ -1290,6 +1318,11 @@ class EngineCoreProc(EngineCore):
for input_socket, _ in poller.poll(): for input_socket, _ in poller.poll():
# (RequestType, RequestData) # (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(copy=False) type_frame, *data_frames = input_socket.recv_multipart(copy=False)
# NOTE(yongji): ignore READY message sent by DP coordinator
# that is used to notify newly started engines
if type_frame.buffer == b"READY":
assert input_socket == coord_socket
continue
request_type = EngineCoreRequestType(bytes(type_frame.buffer)) request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data. # Deserialize the request data.
...@@ -1488,6 +1521,10 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1488,6 +1521,10 @@ class DPEngineCoreProc(EngineCoreProc):
self.current_wave = 0 self.current_wave = 0
self.last_counts = (0, 0) self.last_counts = (0, 0)
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.eep_scaling_state: ElasticEPScalingState | None = None
# Initialize the engine. # Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__( super().__init__(
...@@ -1511,7 +1548,9 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1511,7 +1548,9 @@ class DPEngineCoreProc(EngineCoreProc):
assert 0 <= local_dp_rank <= dp_rank < dp_size assert 0 <= local_dp_rank <= dp_rank < dp_size
self.dp_rank = dp_rank self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() self.dp_group, self.dp_store = (
vllm_config.parallel_config.stateless_init_dp_group(return_store=True)
)
def shutdown(self): def shutdown(self):
super().shutdown() super().shutdown()
...@@ -1574,7 +1613,12 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1574,7 +1613,12 @@ class DPEngineCoreProc(EngineCoreProc):
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
self._process_input_queue() self._process_input_queue()
# 2) Step the engine core. if self.eep_scaling_state is not None:
_ = self.eep_scaling_state.progress()
if self.eep_scaling_state.is_complete():
self.process_input_queue_block = True
self.eep_scaling_state = None
executed = self._process_engine_step() executed = self._process_engine_step()
self._maybe_publish_request_counts() self._maybe_publish_request_counts()
...@@ -1624,54 +1668,129 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1624,54 +1668,129 @@ class DPEngineCoreProc(EngineCoreProc):
def reinitialize_distributed( def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest self, reconfig_request: ReconfigureDistributedRequest
) -> None: ) -> None:
stateless_destroy_torch_distributed_process_group(self.dp_group) from copy import deepcopy
self.shutdown()
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
parallel_config = self.vllm_config.parallel_config
old_dp_size = parallel_config.data_parallel_size new_parallel_config = deepcopy(self.vllm_config.parallel_config)
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size old_dp_size = new_parallel_config.data_parallel_size
if reconfig_request.new_data_parallel_rank != -1: new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank if (
# local rank specifies device visibility, it should not be changed reconfig_request.new_data_parallel_rank
assert ( != ReconfigureRankType.KEEP_CURRENT_RANK
reconfig_request.new_data_parallel_rank_local ):
== ReconfigureRankType.KEEP_CURRENT_RANK new_parallel_config.data_parallel_rank = (
) reconfig_request.new_data_parallel_rank
parallel_config.data_parallel_master_ip = ( )
new_parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip reconfig_request.new_data_parallel_master_ip
) )
parallel_config.data_parallel_master_port = ( new_parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port reconfig_request.new_data_parallel_master_port
) )
if reconfig_request.new_data_parallel_rank != -2: new_parallel_config._data_parallel_master_port_list = (
self.dp_rank = parallel_config.data_parallel_rank reconfig_request.new_data_parallel_master_port_list
self.dp_group = parallel_config.stateless_init_dp_group()
reconfig_request.new_data_parallel_master_port = (
parallel_config.data_parallel_master_port
) )
self.model_executor.reinitialize_distributed(reconfig_request) is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
if reconfig_request.new_data_parallel_size > old_dp_size: is_shutdown = (
assert self.available_gpu_memory_for_kv_cache > 0
# pass available_gpu_memory_for_kv_cache from existing
# engine-cores to new engine-cores so they can directly
# use it in _initialize_kv_caches() rather than profiling.
ParallelConfig.sync_kv_cache_memory_size(
self.dp_group, self.available_gpu_memory_for_kv_cache
)
# NOTE(yongji): newly joined workers require dummy_run even
# CUDA graph is not used
self.model_executor.collective_rpc("compile_or_warm_up_model")
if (
reconfig_request.new_data_parallel_rank reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
): )
self.shutdown()
logger.info("DPEngineCoreProc %s shutdown", self.dp_rank) self.eep_scaling_state = ElasticEPScalingState(
model_executor=self.model_executor,
engine_core=self,
vllm_config=self.vllm_config,
new_parallel_config=new_parallel_config,
worker_type="removing" if is_shutdown else "existing",
scale_type="scale_down" if is_scale_down else "scale_up",
reconfig_request=reconfig_request,
)
self.process_input_queue_block = False
logger.info(
"[Elastic EP] Received reconfiguration request and starting scaling up/down"
)
def _eep_send_engine_core_notification(
self,
notification_type: EEPNotificationType,
vllm_config: VllmConfig | None = None,
):
"""
Send notifications to EngineCoreClient, which can then forward
the notifications to other engine core processes. It is used for:
1) In scale up: new core engines to notify exisiting core engines
that they are ready;
2) In scale down: removing core engines to notify EngineCoreClient
so EngineCoreClient can release their ray placement groups;
3) Both scale up/down: to notify EngineCoreClient that exisiting
core engines have already switched to the new parallel setup.
"""
if vllm_config is None:
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
else: else:
logger.info( dp_rank = vllm_config.parallel_config.data_parallel_rank
"Distributed environment reinitialized for DP rank %s", self.dp_rank notification_data = (notification_type.value, dp_rank)
outputs = EngineCoreOutputs(
utility_output=UtilityOutput(
call_id=EEP_NOTIFICATION_CALL_ID,
result=UtilityResult(notification_data),
) )
)
outputs.engine_index = self.engine_index
if hasattr(self, "output_thread") and self.output_thread.is_alive():
self.output_queue.put_nowait((0, outputs))
else:
encoder = MsgpackEncoder()
with (
zmq.Context() as ctx,
make_zmq_socket(
ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000
) as socket,
):
socket.send_multipart(encoder.encode(outputs))
def eep_handle_engine_core_notification(
self, notification_type: str | EEPNotificationType
):
"""
Handle notification received from EngineCoreClient
(forwarded from new core engines).
"""
assert self.eep_scaling_state is not None
if isinstance(notification_type, str):
notification_type = EEPNotificationType(notification_type)
self.eep_scaling_state.handle_notification(notification_type)
def _eep_scale_up_before_kv_init(self):
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.eep_scaling_state = ElasticEPScalingState(
model_executor=self.model_executor,
engine_core=self,
vllm_config=self.vllm_config,
new_parallel_config=self.vllm_config.parallel_config,
worker_type="new",
scale_type="scale_up",
reconfig_request=None,
)
self.model_executor.collective_rpc("init_device")
self.model_executor.collective_rpc("load_model")
self._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("receive_weights",)
)
self.available_gpu_memory_for_kv_cache = (
ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1)
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("prepare_new_worker",)
)
self.process_input_queue_block = False
class EngineCoreActorMixin: class EngineCoreActorMixin:
......
This diff is collapsed.
...@@ -277,6 +277,8 @@ class CoreEngineActorManager: ...@@ -277,6 +277,8 @@ class CoreEngineActorManager:
else: else:
ray.init() ray.init()
vllm_config.parallel_config.allocate_elastic_ep_ports()
if placement_groups is not None: if placement_groups is not None:
assert local_dp_ranks is not None, ( assert local_dp_ranks is not None, (
"local_dp_ranks must be provided if placement_groups is provided" "local_dp_ranks must be provided if placement_groups is provided"
...@@ -584,6 +586,8 @@ class CoreEngineActorManager: ...@@ -584,6 +586,8 @@ class CoreEngineActorManager:
node_ip = node.node_ip node_ip = node.node_ip
node_id = node.node_id node_id = node.node_id
if device_str not in available_resources[node_id]:
continue
available_gpus = int(available_resources[node_id][device_str]) available_gpus = int(available_resources[node_id][device_str])
# Get total GPUs on this node from the node's resources # Get total GPUs on this node from the node's resources
...@@ -773,26 +777,15 @@ class CoreEngineActorManager: ...@@ -773,26 +777,15 @@ class CoreEngineActorManager:
ray.util.remove_placement_group(pg) ray.util.remove_placement_group(pg)
@contextlib.contextmanager def get_engine_zmq_addresses(
def launch_core_engines(
vllm_config: VllmConfig, vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
num_api_servers: int = 1, num_api_servers: int = 1,
) -> Iterator[ ) -> EngineZmqAddresses:
tuple[ """Allocate ZMQ addresses for engine-client communication."""
CoreEngineProcManager | CoreEngineActorManager | None,
DPCoordinator | None,
EngineZmqAddresses,
]
]:
"""Launch engine and DP coordinator processes as needed."""
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank dp_size = parallel_config.data_parallel_size
host = parallel_config.data_parallel_master_ip host = parallel_config.data_parallel_master_ip
local_engines_only = parallel_config.local_engines_only local_engines_only = parallel_config.local_engines_only
...@@ -806,9 +799,11 @@ def launch_core_engines( ...@@ -806,9 +799,11 @@ def launch_core_engines(
client_local_only = ( client_local_only = (
offline_mode or local_engines_only or (local_engine_count == dp_size) offline_mode or local_engines_only or (local_engine_count == dp_size)
) )
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
client_local_only = False
# Set up input and output addresses. return EngineZmqAddresses(
addresses = EngineZmqAddresses(
inputs=[ inputs=[
get_engine_client_zmq_addr(client_local_only, host) get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers) for _ in range(num_api_servers)
...@@ -819,6 +814,33 @@ def launch_core_engines( ...@@ -819,6 +814,33 @@ def launch_core_engines(
], ],
) )
@contextlib.contextmanager
def launch_core_engines(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
addresses: EngineZmqAddresses,
num_api_servers: int = 1,
) -> Iterator[
tuple[
CoreEngineProcManager | CoreEngineActorManager | None,
DPCoordinator | None,
EngineZmqAddresses,
]
]:
"""Launch engine and DP coordinator processes as needed."""
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank
host = parallel_config.data_parallel_master_ip
local_engines_only = parallel_config.local_engines_only
offline_mode = local_start_index is not None
# Run the DP Coordinator process with rank 0 when in online DP mode. # Run the DP Coordinator process with rank 0 when in online DP mode.
# The coordinator is needed for: # The coordinator is needed for:
# 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing # 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
...@@ -885,6 +907,10 @@ def launch_core_engines( ...@@ -885,6 +907,10 @@ def launch_core_engines(
# will be False. # will be False.
handshake_local_only = offline_mode or local_engine_count == dp_size handshake_local_only = offline_mode or local_engine_count == dp_size
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
handshake_local_only = False
handshake_address = get_engine_client_zmq_addr( handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port handshake_local_only, host, parallel_config.data_parallel_rpc_port
) )
......
...@@ -38,6 +38,7 @@ from vllm.distributed.parallel_state import ( ...@@ -38,6 +38,7 @@ from vllm.distributed.parallel_state import (
get_pcp_group, get_pcp_group,
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
model_parallel_is_initialized,
) )
from vllm.envs import enable_envs_cache from vllm.envs import enable_envs_cache
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -580,17 +581,20 @@ class WorkerProc: ...@@ -580,17 +581,20 @@ class WorkerProc:
) )
self.async_output_copy_thread.start() self.async_output_copy_thread.start()
# Initialize device
self.worker.init_device()
# Set process title and log prefix
self.setup_proc_title_and_log_prefix( self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel enable_ep=vllm_config.parallel_config.enable_expert_parallel
) )
# Load model # Load model
self._init_message_queues(input_shm_handle, vllm_config) self._init_message_queues(input_shm_handle, vllm_config)
self.worker.load_model() is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.worker.init_device()
# Update process title now that parallel groups are initialized
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
self.worker.load_model()
# Enable environment variable cache (e.g. assume no more # Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point) # environment variable overrides after this point)
...@@ -885,6 +889,13 @@ class WorkerProc: ...@@ -885,6 +889,13 @@ class WorkerProc:
@staticmethod @staticmethod
def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: def setup_proc_title_and_log_prefix(enable_ep: bool) -> None:
# Check if parallel groups are initialized first
if not model_parallel_is_initialized():
# Parallel groups not yet initialized, use default process name
set_process_title(name="Worker")
decorate_logs("Worker")
return
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
dp_rank = get_dp_group().rank_in_group dp_rank = get_dp_group().rank_in_group
pp_size = get_pp_group().world_size pp_size = get_pp_group().world_size
......
...@@ -382,8 +382,10 @@ class RayDistributedExecutor(Executor): ...@@ -382,8 +382,10 @@ class RayDistributedExecutor(Executor):
all_kwargs.append(kwargs) all_kwargs.append(kwargs)
self.collective_rpc("init_worker", args=(all_kwargs,)) self.collective_rpc("init_worker", args=(all_kwargs,))
self.collective_rpc("init_device") is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
self.collective_rpc("load_model") if not is_eep_new_worker:
self.collective_rpc("init_device")
self.collective_rpc("load_model")
for pp_rank in range(self.parallel_config.pipeline_parallel_size): for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([]) self.pp_tp_workers.append([])
......
This diff is collapsed.
...@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner): ...@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner):
v.gpu = v.cpu v.gpu = v.cpu
@instrument(span_name="Loading (CPU)") @instrument(span_name="Loading (CPU)")
def load_model(self, eep_scale_up: bool = False) -> None: def load_model(self, load_dummy_weights: bool = False) -> None:
if load_dummy_weights:
raise ValueError(
"Loading dummy weights (needed for elastic EP scale-up) "
"Is not supported by the CPU Model Runner."
)
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config) self.model = get_model(vllm_config=self.vllm_config)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment