Unverified Commit dea26833 authored by Itay Alroy's avatar Itay Alroy Committed by GitHub
Browse files

[1/N] Elastic EP Milestone 2 (#34861)


Signed-off-by: default avatarYongji Wu <wuyongji317@gmail.com>
Signed-off-by: default avatarItay Alroy <ialroy@nvidia.com>
Signed-off-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: default avatarRon Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: default avatarYongji Wu <wuyongji317@gmail.com>
Co-authored-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: default avatarRon Tourgeman <rtourgeman@nvidia.com>
parent 90805ff4
...@@ -29,6 +29,15 @@ PauseMode = Literal["abort", "wait", "keep"] ...@@ -29,6 +29,15 @@ PauseMode = Literal["abort", "wait", "keep"]
# so form part of the external API. # so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error") FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
EEP_NOTIFICATION_CALL_ID = -1
class EEPNotificationType(enum.Enum):
NEW_CORE_ENGINES_INIT_READY = "NEW_CORE_ENGINES_INIT_READY"
NEW_CORE_ENGINES_WEIGHTS_INIT_READY = "NEW_CORE_ENGINES_WEIGHTS_INIT_READY"
RECONFIGURE_FINISHED = "RECONFIGURE_FINISHED"
SHUTDOWN_COMPLETE = "SHUTDOWN_COMPLETE"
class FinishReason(enum.IntEnum): class FinishReason(enum.IntEnum):
""" """
...@@ -235,6 +244,11 @@ class ReconfigureDistributedRequest(msgspec.Struct): ...@@ -235,6 +244,11 @@ class ReconfigureDistributedRequest(msgspec.Struct):
new_data_parallel_rank_local: int new_data_parallel_rank_local: int
new_data_parallel_master_ip: str new_data_parallel_master_ip: str
new_data_parallel_master_port: int new_data_parallel_master_port: int
new_data_parallel_master_port_list: list[int]
new_stateless_world_group_port_list: list[list[int]]
new_stateless_dp_group_port_list: list[list[int]]
new_stateless_ep_group_port_list: list[list[int]]
new_stateless_eplb_group_port_list: list[list[int]]
class ReconfigureRankType(enum.IntEnum): class ReconfigureRankType(enum.IntEnum):
......
...@@ -20,6 +20,7 @@ from vllm.distributed.weight_transfer.base import ( ...@@ -20,6 +20,7 @@ from vllm.distributed.weight_transfer.base import (
) )
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient, StreamingInput from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -647,7 +648,11 @@ class AsyncLLM(EngineClient): ...@@ -647,7 +648,11 @@ class AsyncLLM(EngineClient):
engine_core = self.engine_core engine_core = self.engine_core
output_processor = self.output_processor output_processor = self.output_processor
log_stats = self.log_stats log_stats = self.log_stats
logger_manager = self.logger_manager # We use a mutable list for logger_manager so that it can be updated
# during elastic EP scaling (see scale_elastic_ep) without creating
# a circular reference via self.
self._logger_ref = [self.logger_manager]
logger_ref = self._logger_ref
renderer = self.renderer renderer = self.renderer
chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
...@@ -691,8 +696,8 @@ class AsyncLLM(EngineClient): ...@@ -691,8 +696,8 @@ class AsyncLLM(EngineClient):
# 4) Logging. # 4) Logging.
# TODO(rob): make into a coroutine and launch it in # TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial. # background thread once Prometheus overhead is non-trivial.
if logger_manager: if logger_ref[0]:
logger_manager.record( logger_ref[0].record(
engine_idx=outputs.engine_index, engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
...@@ -976,17 +981,13 @@ class AsyncLLM(EngineClient): ...@@ -976,17 +981,13 @@ class AsyncLLM(EngineClient):
new_data_parallel_size, new_data_parallel_size,
) )
return return
if envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS:
logger.info( logger.info(
"Waiting for requests to drain before scaling up to %s engines...", "VLLM_ELASTIC_EP_DRAIN_REQUESTS is set, "
new_data_parallel_size, "waiting for requests to drain before scaling"
) )
await self.wait_for_requests_to_drain(drain_timeout) await self.wait_for_requests_to_drain(drain_timeout)
logger.info(
"Requests have been drained, proceeding with scale to %s engines",
new_data_parallel_size,
)
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# recreate stat loggers # recreate stat loggers
if new_data_parallel_size > old_data_parallel_size and self.log_stats: if new_data_parallel_size > old_data_parallel_size and self.log_stats:
...@@ -999,6 +1000,18 @@ class AsyncLLM(EngineClient): ...@@ -999,6 +1000,18 @@ class AsyncLLM(EngineClient):
engine_idxs=list(range(new_data_parallel_size)), engine_idxs=list(range(new_data_parallel_size)),
custom_stat_loggers=None, custom_stat_loggers=None,
) )
# Update the mutable ref so output_handler picks up the
# new logger without creating a circular reference via self.
if hasattr(self, "_logger_ref"):
self._logger_ref[0] = self.logger_manager
self.logger_manager.log_engine_initialized()
set_scaling_elastic_ep(True)
try:
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
finally:
set_scaling_elastic_ep(False)
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
......
...@@ -71,6 +71,9 @@ class DPCoordinator: ...@@ -71,6 +71,9 @@ class DPCoordinator:
) )
local_only_eng = dp_size == parallel_config.data_parallel_size_local local_only_eng = dp_size == parallel_config.data_parallel_size_local
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
local_only_eng = False
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host) back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
...@@ -201,6 +204,7 @@ class DPCoordinatorProc: ...@@ -201,6 +204,7 @@ class DPCoordinatorProc:
poller = zmq.Poller() poller = zmq.Poller()
poller.register(publish_front, zmq.POLLIN) poller.register(publish_front, zmq.POLLIN)
poller.register(publish_back, zmq.POLLIN)
poller.register(output_back, zmq.POLLIN) poller.register(output_back, zmq.POLLIN)
last_publish_time = 0 last_publish_time = 0
while True: while True:
...@@ -231,6 +235,22 @@ class DPCoordinatorProc: ...@@ -231,6 +235,22 @@ class DPCoordinatorProc:
events = dict(events) events = dict(events)
wave_state_changed = False wave_state_changed = False
if publish_back in events:
buffer = publish_back.recv()
if buffer == b"\x01":
# NOTE(yongji): newly started engine subscribed
# We need to send READY message here instead of receiving
# SCALE_ELASTIC_EP notification from engine core client
# as SCALE_ELASTIC_EP is only sent when
# new engines finished initialization.
# Subscription message, on the other hand, is sent
# by each engine during initialization
publish_back.send(b"READY")
else:
logger.error(
"DP Coordinator receives unexpected message from engines"
)
if publish_front in events: if publish_front in events:
buffer = publish_front.recv() buffer = publish_front.recv()
if buffer in (b"\x01", b"\x00"): if buffer in (b"\x01", b"\x00"):
...@@ -259,7 +279,6 @@ class DPCoordinatorProc: ...@@ -259,7 +279,6 @@ class DPCoordinatorProc:
# current_wave # current_wave
# we note that 0 is the wave number for the new # we note that 0 is the wave number for the new
# engine # engine
engines_running = False
logger.info( logger.info(
"DPCoordinator scaled up from %s to %s engines", "DPCoordinator scaled up from %s to %s engines",
current_count, current_count,
......
...@@ -17,6 +17,7 @@ from typing import Any, TypeVar, cast ...@@ -17,6 +17,7 @@ from typing import Any, TypeVar, cast
import msgspec import msgspec
import zmq import zmq
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.envs import enable_envs_cache from vllm.envs import enable_envs_cache
...@@ -44,6 +45,8 @@ from vllm.v1.core.kv_cache_utils import ( ...@@ -44,6 +45,8 @@ from vllm.v1.core.kv_cache_utils import (
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ( from vllm.v1.engine import (
EEP_NOTIFICATION_CALL_ID,
EEPNotificationType,
EngineCoreOutput, EngineCoreOutput,
EngineCoreOutputs, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequest,
...@@ -110,6 +113,9 @@ class EngineCore: ...@@ -110,6 +113,9 @@ class EngineCore:
self.available_gpu_memory_for_kv_cache = -1 self.available_gpu_memory_for_kv_cache = -1
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self._eep_scale_up_before_kv_init()
# Setup KV Caches and update CacheConfig after profiling. # Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
vllm_config vllm_config
...@@ -233,12 +239,10 @@ class EngineCore: ...@@ -233,12 +239,10 @@ class EngineCore:
has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs) has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
if has_kv_cache: if has_kv_cache:
if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1": if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
dp_group = getattr(self, "dp_group", None) # NOTE(yongji): should already be set
assert dp_group is not None # during _eep_scale_up_before_kv_init
self.available_gpu_memory_for_kv_cache = ( assert self.available_gpu_memory_for_kv_cache > 0
ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
)
available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len( available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
kv_cache_specs kv_cache_specs
) )
...@@ -752,11 +756,22 @@ class EngineCore: ...@@ -752,11 +756,22 @@ class EngineCore:
self.structured_output_manager.grammar_init(req) self.structured_output_manager.grammar_init(req)
return req, request.current_wave return req, request.current_wave
def _eep_scale_up_before_kv_init(self):
raise NotImplementedError
def _eep_send_engine_core_notification(
self,
notification_type: EEPNotificationType,
vllm_config: VllmConfig | None = None,
):
raise NotImplementedError
class EngineCoreProc(EngineCore): class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process.""" """ZMQ-wrapper for running EngineCore in background process."""
ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD" ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
addresses: EngineZmqAddresses
@instrument(span_name="EngineCoreProc init") @instrument(span_name="EngineCoreProc init")
def __init__( def __init__(
...@@ -807,6 +822,13 @@ class EngineCoreProc(EngineCore): ...@@ -807,6 +822,13 @@ class EngineCoreProc(EngineCore):
# and "hybrid" LB modes. # and "hybrid" LB modes.
self.publish_dp_lb_stats = internal_dp_balancing self.publish_dp_lb_stats = internal_dp_balancing
self.addresses = addresses
self.process_input_queue_block = True
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_INIT_READY,
vllm_config=vllm_config,
)
self._init_data_parallel(vllm_config) self._init_data_parallel(vllm_config)
super().__init__( super().__init__(
...@@ -1119,8 +1141,14 @@ class EngineCoreProc(EngineCore): ...@@ -1119,8 +1141,14 @@ class EngineCoreProc(EngineCore):
if logger.isEnabledFor(DEBUG): if logger.isEnabledFor(DEBUG):
logger.debug("EngineCore waiting for work.") logger.debug("EngineCore waiting for work.")
waited = True waited = True
req = self.input_queue.get() block = self.process_input_queue_block
try:
req = self.input_queue.get(block=block)
self._handle_client_request(*req) self._handle_client_request(*req)
except queue.Empty:
break
if not block:
break
if waited: if waited:
logger.debug("EngineCore loop active.") logger.debug("EngineCore loop active.")
...@@ -1290,6 +1318,11 @@ class EngineCoreProc(EngineCore): ...@@ -1290,6 +1318,11 @@ class EngineCoreProc(EngineCore):
for input_socket, _ in poller.poll(): for input_socket, _ in poller.poll():
# (RequestType, RequestData) # (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(copy=False) type_frame, *data_frames = input_socket.recv_multipart(copy=False)
# NOTE(yongji): ignore READY message sent by DP coordinator
# that is used to notify newly started engines
if type_frame.buffer == b"READY":
assert input_socket == coord_socket
continue
request_type = EngineCoreRequestType(bytes(type_frame.buffer)) request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data. # Deserialize the request data.
...@@ -1488,6 +1521,10 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1488,6 +1521,10 @@ class DPEngineCoreProc(EngineCoreProc):
self.current_wave = 0 self.current_wave = 0
self.last_counts = (0, 0) self.last_counts = (0, 0)
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.eep_scaling_state: ElasticEPScalingState | None = None
# Initialize the engine. # Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__( super().__init__(
...@@ -1511,7 +1548,9 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1511,7 +1548,9 @@ class DPEngineCoreProc(EngineCoreProc):
assert 0 <= local_dp_rank <= dp_rank < dp_size assert 0 <= local_dp_rank <= dp_rank < dp_size
self.dp_rank = dp_rank self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() self.dp_group, self.dp_store = (
vllm_config.parallel_config.stateless_init_dp_group(return_store=True)
)
def shutdown(self): def shutdown(self):
super().shutdown() super().shutdown()
...@@ -1574,7 +1613,12 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1574,7 +1613,12 @@ class DPEngineCoreProc(EngineCoreProc):
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
self._process_input_queue() self._process_input_queue()
# 2) Step the engine core. if self.eep_scaling_state is not None:
_ = self.eep_scaling_state.progress()
if self.eep_scaling_state.is_complete():
self.process_input_queue_block = True
self.eep_scaling_state = None
executed = self._process_engine_step() executed = self._process_engine_step()
self._maybe_publish_request_counts() self._maybe_publish_request_counts()
...@@ -1624,54 +1668,129 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1624,54 +1668,129 @@ class DPEngineCoreProc(EngineCoreProc):
def reinitialize_distributed( def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest self, reconfig_request: ReconfigureDistributedRequest
) -> None: ) -> None:
stateless_destroy_torch_distributed_process_group(self.dp_group) from copy import deepcopy
self.shutdown()
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
parallel_config = self.vllm_config.parallel_config
old_dp_size = parallel_config.data_parallel_size new_parallel_config = deepcopy(self.vllm_config.parallel_config)
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size old_dp_size = new_parallel_config.data_parallel_size
if reconfig_request.new_data_parallel_rank != -1: new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank if (
# local rank specifies device visibility, it should not be changed reconfig_request.new_data_parallel_rank
assert ( != ReconfigureRankType.KEEP_CURRENT_RANK
reconfig_request.new_data_parallel_rank_local ):
== ReconfigureRankType.KEEP_CURRENT_RANK new_parallel_config.data_parallel_rank = (
) reconfig_request.new_data_parallel_rank
parallel_config.data_parallel_master_ip = ( )
new_parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip reconfig_request.new_data_parallel_master_ip
) )
parallel_config.data_parallel_master_port = ( new_parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port reconfig_request.new_data_parallel_master_port
) )
if reconfig_request.new_data_parallel_rank != -2: new_parallel_config._data_parallel_master_port_list = (
self.dp_rank = parallel_config.data_parallel_rank reconfig_request.new_data_parallel_master_port_list
self.dp_group = parallel_config.stateless_init_dp_group()
reconfig_request.new_data_parallel_master_port = (
parallel_config.data_parallel_master_port
) )
self.model_executor.reinitialize_distributed(reconfig_request) is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
if reconfig_request.new_data_parallel_size > old_dp_size: is_shutdown = (
assert self.available_gpu_memory_for_kv_cache > 0
# pass available_gpu_memory_for_kv_cache from existing
# engine-cores to new engine-cores so they can directly
# use it in _initialize_kv_caches() rather than profiling.
ParallelConfig.sync_kv_cache_memory_size(
self.dp_group, self.available_gpu_memory_for_kv_cache
)
# NOTE(yongji): newly joined workers require dummy_run even
# CUDA graph is not used
self.model_executor.collective_rpc("compile_or_warm_up_model")
if (
reconfig_request.new_data_parallel_rank reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
)
self.eep_scaling_state = ElasticEPScalingState(
model_executor=self.model_executor,
engine_core=self,
vllm_config=self.vllm_config,
new_parallel_config=new_parallel_config,
worker_type="removing" if is_shutdown else "existing",
scale_type="scale_down" if is_scale_down else "scale_up",
reconfig_request=reconfig_request,
)
self.process_input_queue_block = False
logger.info(
"[Elastic EP] Received reconfiguration request and starting scaling up/down"
)
def _eep_send_engine_core_notification(
self,
notification_type: EEPNotificationType,
vllm_config: VllmConfig | None = None,
): ):
self.shutdown() """
logger.info("DPEngineCoreProc %s shutdown", self.dp_rank) Send notifications to EngineCoreClient, which can then forward
the notifications to other engine core processes. It is used for:
1) In scale up: new core engines to notify exisiting core engines
that they are ready;
2) In scale down: removing core engines to notify EngineCoreClient
so EngineCoreClient can release their ray placement groups;
3) Both scale up/down: to notify EngineCoreClient that exisiting
core engines have already switched to the new parallel setup.
"""
if vllm_config is None:
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
else: else:
logger.info( dp_rank = vllm_config.parallel_config.data_parallel_rank
"Distributed environment reinitialized for DP rank %s", self.dp_rank notification_data = (notification_type.value, dp_rank)
outputs = EngineCoreOutputs(
utility_output=UtilityOutput(
call_id=EEP_NOTIFICATION_CALL_ID,
result=UtilityResult(notification_data),
)
)
outputs.engine_index = self.engine_index
if hasattr(self, "output_thread") and self.output_thread.is_alive():
self.output_queue.put_nowait((0, outputs))
else:
encoder = MsgpackEncoder()
with (
zmq.Context() as ctx,
make_zmq_socket(
ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000
) as socket,
):
socket.send_multipart(encoder.encode(outputs))
def eep_handle_engine_core_notification(
self, notification_type: str | EEPNotificationType
):
"""
Handle notification received from EngineCoreClient
(forwarded from new core engines).
"""
assert self.eep_scaling_state is not None
if isinstance(notification_type, str):
notification_type = EEPNotificationType(notification_type)
self.eep_scaling_state.handle_notification(notification_type)
def _eep_scale_up_before_kv_init(self):
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.eep_scaling_state = ElasticEPScalingState(
model_executor=self.model_executor,
engine_core=self,
vllm_config=self.vllm_config,
new_parallel_config=self.vllm_config.parallel_config,
worker_type="new",
scale_type="scale_up",
reconfig_request=None,
)
self.model_executor.collective_rpc("init_device")
self.model_executor.collective_rpc("load_model")
self._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("receive_weights",)
)
self.available_gpu_memory_for_kv_cache = (
ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1)
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("prepare_new_worker",)
) )
self.process_input_queue_block = False
class EngineCoreActorMixin: class EngineCoreActorMixin:
......
This diff is collapsed.
...@@ -277,6 +277,8 @@ class CoreEngineActorManager: ...@@ -277,6 +277,8 @@ class CoreEngineActorManager:
else: else:
ray.init() ray.init()
vllm_config.parallel_config.allocate_elastic_ep_ports()
if placement_groups is not None: if placement_groups is not None:
assert local_dp_ranks is not None, ( assert local_dp_ranks is not None, (
"local_dp_ranks must be provided if placement_groups is provided" "local_dp_ranks must be provided if placement_groups is provided"
...@@ -584,6 +586,8 @@ class CoreEngineActorManager: ...@@ -584,6 +586,8 @@ class CoreEngineActorManager:
node_ip = node.node_ip node_ip = node.node_ip
node_id = node.node_id node_id = node.node_id
if device_str not in available_resources[node_id]:
continue
available_gpus = int(available_resources[node_id][device_str]) available_gpus = int(available_resources[node_id][device_str])
# Get total GPUs on this node from the node's resources # Get total GPUs on this node from the node's resources
...@@ -773,26 +777,15 @@ class CoreEngineActorManager: ...@@ -773,26 +777,15 @@ class CoreEngineActorManager:
ray.util.remove_placement_group(pg) ray.util.remove_placement_group(pg)
@contextlib.contextmanager def get_engine_zmq_addresses(
def launch_core_engines(
vllm_config: VllmConfig, vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
num_api_servers: int = 1, num_api_servers: int = 1,
) -> Iterator[ ) -> EngineZmqAddresses:
tuple[ """Allocate ZMQ addresses for engine-client communication."""
CoreEngineProcManager | CoreEngineActorManager | None,
DPCoordinator | None,
EngineZmqAddresses,
]
]:
"""Launch engine and DP coordinator processes as needed."""
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank dp_size = parallel_config.data_parallel_size
host = parallel_config.data_parallel_master_ip host = parallel_config.data_parallel_master_ip
local_engines_only = parallel_config.local_engines_only local_engines_only = parallel_config.local_engines_only
...@@ -806,9 +799,11 @@ def launch_core_engines( ...@@ -806,9 +799,11 @@ def launch_core_engines(
client_local_only = ( client_local_only = (
offline_mode or local_engines_only or (local_engine_count == dp_size) offline_mode or local_engines_only or (local_engine_count == dp_size)
) )
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
client_local_only = False
# Set up input and output addresses. return EngineZmqAddresses(
addresses = EngineZmqAddresses(
inputs=[ inputs=[
get_engine_client_zmq_addr(client_local_only, host) get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers) for _ in range(num_api_servers)
...@@ -819,6 +814,33 @@ def launch_core_engines( ...@@ -819,6 +814,33 @@ def launch_core_engines(
], ],
) )
@contextlib.contextmanager
def launch_core_engines(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
addresses: EngineZmqAddresses,
num_api_servers: int = 1,
) -> Iterator[
tuple[
CoreEngineProcManager | CoreEngineActorManager | None,
DPCoordinator | None,
EngineZmqAddresses,
]
]:
"""Launch engine and DP coordinator processes as needed."""
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank
host = parallel_config.data_parallel_master_ip
local_engines_only = parallel_config.local_engines_only
offline_mode = local_start_index is not None
# Run the DP Coordinator process with rank 0 when in online DP mode. # Run the DP Coordinator process with rank 0 when in online DP mode.
# The coordinator is needed for: # The coordinator is needed for:
# 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing # 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
...@@ -885,6 +907,10 @@ def launch_core_engines( ...@@ -885,6 +907,10 @@ def launch_core_engines(
# will be False. # will be False.
handshake_local_only = offline_mode or local_engine_count == dp_size handshake_local_only = offline_mode or local_engine_count == dp_size
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
handshake_local_only = False
handshake_address = get_engine_client_zmq_addr( handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port handshake_local_only, host, parallel_config.data_parallel_rpc_port
) )
......
...@@ -38,6 +38,7 @@ from vllm.distributed.parallel_state import ( ...@@ -38,6 +38,7 @@ from vllm.distributed.parallel_state import (
get_pcp_group, get_pcp_group,
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
model_parallel_is_initialized,
) )
from vllm.envs import enable_envs_cache from vllm.envs import enable_envs_cache
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -580,16 +581,19 @@ class WorkerProc: ...@@ -580,16 +581,19 @@ class WorkerProc:
) )
self.async_output_copy_thread.start() self.async_output_copy_thread.start()
# Initialize device
self.worker.init_device()
# Set process title and log prefix
self.setup_proc_title_and_log_prefix( self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel enable_ep=vllm_config.parallel_config.enable_expert_parallel
) )
# Load model # Load model
self._init_message_queues(input_shm_handle, vllm_config) self._init_message_queues(input_shm_handle, vllm_config)
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.worker.init_device()
# Update process title now that parallel groups are initialized
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
self.worker.load_model() self.worker.load_model()
# Enable environment variable cache (e.g. assume no more # Enable environment variable cache (e.g. assume no more
...@@ -885,6 +889,13 @@ class WorkerProc: ...@@ -885,6 +889,13 @@ class WorkerProc:
@staticmethod @staticmethod
def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: def setup_proc_title_and_log_prefix(enable_ep: bool) -> None:
# Check if parallel groups are initialized first
if not model_parallel_is_initialized():
# Parallel groups not yet initialized, use default process name
set_process_title(name="Worker")
decorate_logs("Worker")
return
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
dp_rank = get_dp_group().rank_in_group dp_rank = get_dp_group().rank_in_group
pp_size = get_pp_group().world_size pp_size = get_pp_group().world_size
......
...@@ -382,6 +382,8 @@ class RayDistributedExecutor(Executor): ...@@ -382,6 +382,8 @@ class RayDistributedExecutor(Executor):
all_kwargs.append(kwargs) all_kwargs.append(kwargs)
self.collective_rpc("init_worker", args=(all_kwargs,)) self.collective_rpc("init_worker", args=(all_kwargs,))
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.collective_rpc("init_device") self.collective_rpc("init_device")
self.collective_rpc("load_model") self.collective_rpc("load_model")
......
...@@ -14,7 +14,6 @@ import vllm.envs as envs ...@@ -14,7 +14,6 @@ import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.serial_utils import run_method from vllm.v1.serial_utils import run_method
...@@ -43,7 +42,9 @@ class UniProcExecutor(Executor): ...@@ -43,7 +42,9 @@ class UniProcExecutor(Executor):
max_workers=1, thread_name_prefix="WorkerAsyncOutput" max_workers=1, thread_name_prefix="WorkerAsyncOutput"
) )
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
self.driver_worker.init_worker(all_kwargs=[kwargs]) self.driver_worker.init_worker(all_kwargs=[kwargs])
if not is_eep_new_worker:
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
...@@ -122,16 +123,6 @@ class UniProcExecutor(Executor): ...@@ -122,16 +123,6 @@ class UniProcExecutor(Executor):
# it's running. # it's running.
return return
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
self.driver_worker.reinitialize_distributed(reconfig_request)
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
self.shutdown()
def shutdown(self) -> None: def shutdown(self) -> None:
if worker := self.driver_worker: if worker := self.driver_worker:
worker.shutdown() worker.shutdown()
......
...@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner): ...@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner):
v.gpu = v.cpu v.gpu = v.cpu
@instrument(span_name="Loading (CPU)") @instrument(span_name="Loading (CPU)")
def load_model(self, eep_scale_up: bool = False) -> None: def load_model(self, load_dummy_weights: bool = False) -> None:
if load_dummy_weights:
raise ValueError(
"Loading dummy weights (needed for elastic EP scale-up) "
"Is not supported by the CPU Model Runner."
)
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config) self.model = get_model(vllm_config=self.vllm_config)
......
...@@ -461,6 +461,8 @@ class GPUModelRunner( ...@@ -461,6 +461,8 @@ class GPUModelRunner(
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
self.eplb_state: EplbState | None = None self.eplb_state: EplbState | None = None
# NOTE(yongji): flag to temporarily disable EPLB during scaling up/down
self.eep_eplb_suppressed = False
""" """
State of the expert parallelism load balancer. State of the expert parallelism load balancer.
...@@ -2702,7 +2704,7 @@ class GPUModelRunner( ...@@ -2702,7 +2704,7 @@ class GPUModelRunner(
""" """
Step for the EPLB (Expert Parallelism Load Balancing) state. Step for the EPLB (Expert Parallelism Load Balancing) state.
""" """
if not self.parallel_config.enable_eplb: if not self.parallel_config.enable_eplb or self.eep_eplb_suppressed:
return return
assert self.eplb_state is not None assert self.eplb_state is not None
...@@ -2714,6 +2716,23 @@ class GPUModelRunner( ...@@ -2714,6 +2716,23 @@ class GPUModelRunner(
log_stats=self.parallel_config.eplb_config.log_balancedness, log_stats=self.parallel_config.eplb_config.log_balancedness,
) )
def setup_eplb_from_mapping(
self,
expanded_physical_to_logical: torch.Tensor,
old_num_physical_experts: int,
) -> None:
model = self.get_model()
assert is_mixture_of_experts(model)
self.eplb_state = EplbState.from_mapping(
model=model,
model_config=self.model_config,
device=self.device,
parallel_config=self.parallel_config,
expanded_physical_to_logical=expanded_physical_to_logical,
num_valid_physical_experts=old_num_physical_experts,
)
def _pool( def _pool(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -4175,21 +4194,16 @@ class GPUModelRunner( ...@@ -4175,21 +4194,16 @@ class GPUModelRunner(
setattr(self, config_name, new_config) setattr(self, config_name, new_config)
@instrument(span_name="Loading (GPU)") @instrument(span_name="Loading (GPU)")
def load_model(self, eep_scale_up: bool = False) -> None: def load_model(self, load_dummy_weights: bool = False) -> None:
""" """
Args: Args:
eep_scale_up: the model loading is for elastic EP scale up. load_dummy_weights: load dummy weights instead of real weights.
""" """
logger.info_once( logger.info_once(
"Starting to load model %s...", "Starting to load model %s...",
self.model_config.model, self.model_config.model,
scope="global", scope="global",
) )
global_expert_loads, old_global_expert_indices_per_model, rank_mapping = (
EplbState.get_eep_state(self.parallel_config)
if eep_scale_up
else (None, None, None)
)
if self.parallel_config.enable_eplb: if self.parallel_config.enable_eplb:
self.eplb_state = EplbState(self.parallel_config, self.device) self.eplb_state = EplbState(self.parallel_config, self.device)
...@@ -4198,6 +4212,8 @@ class GPUModelRunner( ...@@ -4198,6 +4212,8 @@ class GPUModelRunner(
try: try:
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler() as m:
time_before_load = time.perf_counter() time_before_load = time.perf_counter()
if load_dummy_weights:
self.load_config.load_format = "dummy"
model_loader = get_model_loader(self.load_config) model_loader = get_model_loader(self.load_config)
self.model = model_loader.load_model( self.model = model_loader.load_model(
vllm_config=self.vllm_config, model_config=self.model_config vllm_config=self.vllm_config, model_config=self.model_config
...@@ -4214,6 +4230,9 @@ class GPUModelRunner( ...@@ -4214,6 +4230,9 @@ class GPUModelRunner(
and is_mixture_of_experts(self.drafter.model) and is_mixture_of_experts(self.drafter.model)
and self.parallel_config.enable_eplb and self.parallel_config.enable_eplb
): ):
assert not self.parallel_config.enable_elastic_ep, (
"Elastic EP is not supported with drafter model."
)
spec_config = self.vllm_config.speculative_config spec_config = self.vllm_config.speculative_config
assert spec_config is not None assert spec_config is not None
assert spec_config.draft_model_config is not None assert spec_config.draft_model_config is not None
...@@ -4221,17 +4240,6 @@ class GPUModelRunner( ...@@ -4221,17 +4240,6 @@ class GPUModelRunner(
"EPLB is enabled for drafter model %s.", "EPLB is enabled for drafter model %s.",
spec_config.draft_model_config.model, spec_config.draft_model_config.model,
) )
global_expert_load = (
global_expert_loads[eplb_models]
if global_expert_loads
else None
)
old_global_expert_indices = (
old_global_expert_indices_per_model[eplb_models]
if old_global_expert_indices_per_model
else None
)
if self.eplb_state is None: if self.eplb_state is None:
self.eplb_state = EplbState( self.eplb_state = EplbState(
self.parallel_config, self.device self.parallel_config, self.device
...@@ -4239,9 +4247,6 @@ class GPUModelRunner( ...@@ -4239,9 +4247,6 @@ class GPUModelRunner(
self.eplb_state.add_model( self.eplb_state.add_model(
self.drafter.model, self.drafter.model,
spec_config.draft_model_config, spec_config.draft_model_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
) )
eplb_models += 1 eplb_models += 1
...@@ -4283,6 +4288,7 @@ class GPUModelRunner( ...@@ -4283,6 +4288,7 @@ class GPUModelRunner(
time_after_load - time_before_load, time_after_load - time_before_load,
scope="local", scope="local",
) )
if not load_dummy_weights:
prepare_communication_buffer_for_model(self.model) prepare_communication_buffer_for_model(self.model)
if (drafter := getattr(self, "drafter", None)) and ( if (drafter := getattr(self, "drafter", None)) and (
drafter_model := getattr(drafter, "model", None) drafter_model := getattr(drafter, "model", None)
...@@ -4295,26 +4301,19 @@ class GPUModelRunner( ...@@ -4295,26 +4301,19 @@ class GPUModelRunner(
and mm_config.is_multimodal_pruning_enabled() and mm_config.is_multimodal_pruning_enabled()
) )
if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: if (
is_mixture_of_experts(self.model)
and self.parallel_config.enable_eplb
and not load_dummy_weights
):
logger.info_once("EPLB is enabled for model %s.", self.model_config.model) logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
global_expert_load = (
global_expert_loads[eplb_models] if global_expert_loads else None
)
old_global_expert_indices = (
old_global_expert_indices_per_model[eplb_models]
if old_global_expert_indices_per_model
else None
)
assert self.eplb_state is not None assert self.eplb_state is not None
self.eplb_state.add_model( self.eplb_state.add_model(
self.model, self.model,
self.model_config, self.model_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
) )
if self.eplb_state.is_async: if self.eplb_state.is_async:
self.eplb_state.start_async_loop(rank_mapping=rank_mapping) self.eplb_state.start_async_loop()
if ( if (
self.vllm_config.compilation_config.mode self.vllm_config.compilation_config.mode
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment