Unverified Commit dea26833 authored by Itay Alroy's avatar Itay Alroy Committed by GitHub
Browse files

[1/N] Elastic EP Milestone 2 (#34861)


Signed-off-by: default avatarYongji Wu <wuyongji317@gmail.com>
Signed-off-by: default avatarItay Alroy <ialroy@nvidia.com>
Signed-off-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: default avatarRon Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: default avatarYongji Wu <wuyongji317@gmail.com>
Co-authored-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: default avatarRon Tourgeman <rtourgeman@nvidia.com>
parent 90805ff4
......@@ -29,6 +29,15 @@ PauseMode = Literal["abort", "wait", "keep"]
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
EEP_NOTIFICATION_CALL_ID = -1
class EEPNotificationType(enum.Enum):
NEW_CORE_ENGINES_INIT_READY = "NEW_CORE_ENGINES_INIT_READY"
NEW_CORE_ENGINES_WEIGHTS_INIT_READY = "NEW_CORE_ENGINES_WEIGHTS_INIT_READY"
RECONFIGURE_FINISHED = "RECONFIGURE_FINISHED"
SHUTDOWN_COMPLETE = "SHUTDOWN_COMPLETE"
class FinishReason(enum.IntEnum):
"""
......@@ -235,6 +244,11 @@ class ReconfigureDistributedRequest(msgspec.Struct):
new_data_parallel_rank_local: int
new_data_parallel_master_ip: str
new_data_parallel_master_port: int
new_data_parallel_master_port_list: list[int]
new_stateless_world_group_port_list: list[list[int]]
new_stateless_dp_group_port_list: list[list[int]]
new_stateless_ep_group_port_list: list[list[int]]
new_stateless_eplb_group_port_list: list[list[int]]
class ReconfigureRankType(enum.IntEnum):
......
......@@ -20,6 +20,7 @@ from vllm.distributed.weight_transfer.base import (
)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
......@@ -647,7 +648,11 @@ class AsyncLLM(EngineClient):
engine_core = self.engine_core
output_processor = self.output_processor
log_stats = self.log_stats
logger_manager = self.logger_manager
# We use a mutable list for logger_manager so that it can be updated
# during elastic EP scaling (see scale_elastic_ep) without creating
# a circular reference via self.
self._logger_ref = [self.logger_manager]
logger_ref = self._logger_ref
renderer = self.renderer
chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
......@@ -691,8 +696,8 @@ class AsyncLLM(EngineClient):
# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
if logger_manager:
logger_manager.record(
if logger_ref[0]:
logger_ref[0].record(
engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
......@@ -976,17 +981,13 @@ class AsyncLLM(EngineClient):
new_data_parallel_size,
)
return
logger.info(
"Waiting for requests to drain before scaling up to %s engines...",
new_data_parallel_size,
)
await self.wait_for_requests_to_drain(drain_timeout)
logger.info(
"Requests have been drained, proceeding with scale to %s engines",
new_data_parallel_size,
)
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
if envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS:
logger.info(
"VLLM_ELASTIC_EP_DRAIN_REQUESTS is set, "
"waiting for requests to drain before scaling"
)
await self.wait_for_requests_to_drain(drain_timeout)
# recreate stat loggers
if new_data_parallel_size > old_data_parallel_size and self.log_stats:
......@@ -999,6 +1000,18 @@ class AsyncLLM(EngineClient):
engine_idxs=list(range(new_data_parallel_size)),
custom_stat_loggers=None,
)
# Update the mutable ref so output_handler picks up the
# new logger without creating a circular reference via self.
if hasattr(self, "_logger_ref"):
self._logger_ref[0] = self.logger_manager
self.logger_manager.log_engine_initialized()
set_scaling_elastic_ep(True)
try:
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
finally:
set_scaling_elastic_ep(False)
@property
def is_running(self) -> bool:
......
......@@ -71,6 +71,9 @@ class DPCoordinator:
)
local_only_eng = dp_size == parallel_config.data_parallel_size_local
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
local_only_eng = False
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
......@@ -201,6 +204,7 @@ class DPCoordinatorProc:
poller = zmq.Poller()
poller.register(publish_front, zmq.POLLIN)
poller.register(publish_back, zmq.POLLIN)
poller.register(output_back, zmq.POLLIN)
last_publish_time = 0
while True:
......@@ -231,6 +235,22 @@ class DPCoordinatorProc:
events = dict(events)
wave_state_changed = False
if publish_back in events:
buffer = publish_back.recv()
if buffer == b"\x01":
# NOTE(yongji): newly started engine subscribed
# We need to send READY message here instead of receiving
# SCALE_ELASTIC_EP notification from engine core client
# as SCALE_ELASTIC_EP is only sent when
# new engines finished initialization.
# Subscription message, on the other hand, is sent
# by each engine during initialization
publish_back.send(b"READY")
else:
logger.error(
"DP Coordinator receives unexpected message from engines"
)
if publish_front in events:
buffer = publish_front.recv()
if buffer in (b"\x01", b"\x00"):
......@@ -259,7 +279,6 @@ class DPCoordinatorProc:
# current_wave
# we note that 0 is the wave number for the new
# engine
engines_running = False
logger.info(
"DPCoordinator scaled up from %s to %s engines",
current_count,
......
......@@ -17,6 +17,7 @@ from typing import Any, TypeVar, cast
import msgspec
import zmq
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.envs import enable_envs_cache
......@@ -44,6 +45,8 @@ from vllm.v1.core.kv_cache_utils import (
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import (
EEP_NOTIFICATION_CALL_ID,
EEPNotificationType,
EngineCoreOutput,
EngineCoreOutputs,
EngineCoreRequest,
......@@ -110,6 +113,9 @@ class EngineCore:
self.available_gpu_memory_for_kv_cache = -1
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self._eep_scale_up_before_kv_init()
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
vllm_config
......@@ -233,12 +239,10 @@ class EngineCore:
has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
if has_kv_cache:
if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
dp_group = getattr(self, "dp_group", None)
assert dp_group is not None
self.available_gpu_memory_for_kv_cache = (
ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
)
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
# NOTE(yongji): should already be set
# during _eep_scale_up_before_kv_init
assert self.available_gpu_memory_for_kv_cache > 0
available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
kv_cache_specs
)
......@@ -752,11 +756,22 @@ class EngineCore:
self.structured_output_manager.grammar_init(req)
return req, request.current_wave
def _eep_scale_up_before_kv_init(self):
raise NotImplementedError
def _eep_send_engine_core_notification(
self,
notification_type: EEPNotificationType,
vllm_config: VllmConfig | None = None,
):
raise NotImplementedError
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
addresses: EngineZmqAddresses
@instrument(span_name="EngineCoreProc init")
def __init__(
......@@ -807,6 +822,13 @@ class EngineCoreProc(EngineCore):
# and "hybrid" LB modes.
self.publish_dp_lb_stats = internal_dp_balancing
self.addresses = addresses
self.process_input_queue_block = True
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_INIT_READY,
vllm_config=vllm_config,
)
self._init_data_parallel(vllm_config)
super().__init__(
......@@ -1119,8 +1141,14 @@ class EngineCoreProc(EngineCore):
if logger.isEnabledFor(DEBUG):
logger.debug("EngineCore waiting for work.")
waited = True
req = self.input_queue.get()
self._handle_client_request(*req)
block = self.process_input_queue_block
try:
req = self.input_queue.get(block=block)
self._handle_client_request(*req)
except queue.Empty:
break
if not block:
break
if waited:
logger.debug("EngineCore loop active.")
......@@ -1290,6 +1318,11 @@ class EngineCoreProc(EngineCore):
for input_socket, _ in poller.poll():
# (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(copy=False)
# NOTE(yongji): ignore READY message sent by DP coordinator
# that is used to notify newly started engines
if type_frame.buffer == b"READY":
assert input_socket == coord_socket
continue
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data.
......@@ -1488,6 +1521,10 @@ class DPEngineCoreProc(EngineCoreProc):
self.current_wave = 0
self.last_counts = (0, 0)
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.eep_scaling_state: ElasticEPScalingState | None = None
# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(
......@@ -1511,7 +1548,9 @@ class DPEngineCoreProc(EngineCoreProc):
assert 0 <= local_dp_rank <= dp_rank < dp_size
self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
self.dp_group, self.dp_store = (
vllm_config.parallel_config.stateless_init_dp_group(return_store=True)
)
def shutdown(self):
super().shutdown()
......@@ -1574,7 +1613,12 @@ class DPEngineCoreProc(EngineCoreProc):
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core.
if self.eep_scaling_state is not None:
_ = self.eep_scaling_state.progress()
if self.eep_scaling_state.is_complete():
self.process_input_queue_block = True
self.eep_scaling_state = None
executed = self._process_engine_step()
self._maybe_publish_request_counts()
......@@ -1624,54 +1668,129 @@ class DPEngineCoreProc(EngineCoreProc):
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
stateless_destroy_torch_distributed_process_group(self.dp_group)
self.shutdown()
parallel_config = self.vllm_config.parallel_config
old_dp_size = parallel_config.data_parallel_size
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if reconfig_request.new_data_parallel_rank != -1:
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
# local rank specifies device visibility, it should not be changed
assert (
reconfig_request.new_data_parallel_rank_local
== ReconfigureRankType.KEEP_CURRENT_RANK
)
parallel_config.data_parallel_master_ip = (
from copy import deepcopy
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
new_parallel_config = deepcopy(self.vllm_config.parallel_config)
old_dp_size = new_parallel_config.data_parallel_size
new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
new_parallel_config.data_parallel_rank = (
reconfig_request.new_data_parallel_rank
)
new_parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
new_parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
if reconfig_request.new_data_parallel_rank != -2:
self.dp_rank = parallel_config.data_parallel_rank
self.dp_group = parallel_config.stateless_init_dp_group()
reconfig_request.new_data_parallel_master_port = (
parallel_config.data_parallel_master_port
new_parallel_config._data_parallel_master_port_list = (
reconfig_request.new_data_parallel_master_port_list
)
self.model_executor.reinitialize_distributed(reconfig_request)
if reconfig_request.new_data_parallel_size > old_dp_size:
assert self.available_gpu_memory_for_kv_cache > 0
# pass available_gpu_memory_for_kv_cache from existing
# engine-cores to new engine-cores so they can directly
# use it in _initialize_kv_caches() rather than profiling.
ParallelConfig.sync_kv_cache_memory_size(
self.dp_group, self.available_gpu_memory_for_kv_cache
)
# NOTE(yongji): newly joined workers require dummy_run even
# CUDA graph is not used
self.model_executor.collective_rpc("compile_or_warm_up_model")
if (
is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
is_shutdown = (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
self.shutdown()
logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
)
self.eep_scaling_state = ElasticEPScalingState(
model_executor=self.model_executor,
engine_core=self,
vllm_config=self.vllm_config,
new_parallel_config=new_parallel_config,
worker_type="removing" if is_shutdown else "existing",
scale_type="scale_down" if is_scale_down else "scale_up",
reconfig_request=reconfig_request,
)
self.process_input_queue_block = False
logger.info(
"[Elastic EP] Received reconfiguration request and starting scaling up/down"
)
def _eep_send_engine_core_notification(
self,
notification_type: EEPNotificationType,
vllm_config: VllmConfig | None = None,
):
"""
Send notifications to EngineCoreClient, which can then forward
the notifications to other engine core processes. It is used for:
1) In scale up: new core engines to notify exisiting core engines
that they are ready;
2) In scale down: removing core engines to notify EngineCoreClient
so EngineCoreClient can release their ray placement groups;
3) Both scale up/down: to notify EngineCoreClient that exisiting
core engines have already switched to the new parallel setup.
"""
if vllm_config is None:
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
else:
logger.info(
"Distributed environment reinitialized for DP rank %s", self.dp_rank
dp_rank = vllm_config.parallel_config.data_parallel_rank
notification_data = (notification_type.value, dp_rank)
outputs = EngineCoreOutputs(
utility_output=UtilityOutput(
call_id=EEP_NOTIFICATION_CALL_ID,
result=UtilityResult(notification_data),
)
)
outputs.engine_index = self.engine_index
if hasattr(self, "output_thread") and self.output_thread.is_alive():
self.output_queue.put_nowait((0, outputs))
else:
encoder = MsgpackEncoder()
with (
zmq.Context() as ctx,
make_zmq_socket(
ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000
) as socket,
):
socket.send_multipart(encoder.encode(outputs))
def eep_handle_engine_core_notification(
self, notification_type: str | EEPNotificationType
):
"""
Handle notification received from EngineCoreClient
(forwarded from new core engines).
"""
assert self.eep_scaling_state is not None
if isinstance(notification_type, str):
notification_type = EEPNotificationType(notification_type)
self.eep_scaling_state.handle_notification(notification_type)
def _eep_scale_up_before_kv_init(self):
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.eep_scaling_state = ElasticEPScalingState(
model_executor=self.model_executor,
engine_core=self,
vllm_config=self.vllm_config,
new_parallel_config=self.vllm_config.parallel_config,
worker_type="new",
scale_type="scale_up",
reconfig_request=None,
)
self.model_executor.collective_rpc("init_device")
self.model_executor.collective_rpc("load_model")
self._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("receive_weights",)
)
self.available_gpu_memory_for_kv_cache = (
ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1)
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("prepare_new_worker",)
)
self.process_input_queue_block = False
class EngineCoreActorMixin:
......
......@@ -28,11 +28,12 @@ from vllm.tracing import instrument
from vllm.utils.async_utils import in_loop
from vllm.utils.network_utils import (
close_sockets,
get_open_port,
get_open_zmq_inproc_path,
make_zmq_socket,
)
from vllm.v1.engine import (
EEP_NOTIFICATION_CALL_ID,
EEPNotificationType,
EngineCoreOutputs,
EngineCoreRequest,
EngineCoreRequestType,
......@@ -47,6 +48,7 @@ from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.engine.utils import (
CoreEngineActorManager,
CoreEngineProcManager,
get_engine_zmq_addresses,
launch_core_engines,
)
from vllm.v1.executor import Executor
......@@ -445,6 +447,63 @@ class BackgroundResources:
raise EngineDeadError()
@dataclass
class ElasticScalingCache:
existing_core_engines: list[EngineIdentity]
num_new_core_engines: int
pending_notifications: dict[EEPNotificationType, set[int]]
def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int):
"""
Allocate stateless group ports for elastic EP.
"""
from vllm.utils.network_utils import get_open_ports_list
assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled"
world_size = parallel_config.world_size
new_world_size_across_dp = world_size * new_data_parallel_size
num_world_groups = 1
num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size)
num_ep_groups = max(
1,
new_world_size_across_dp
// (new_data_parallel_size * parallel_config.tensor_parallel_size),
)
num_eplb_groups = num_ep_groups
total_ports_needed = (
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
) * 3 + 5
all_ports = get_open_ports_list(total_ports_needed)
new_data_parallel_master_port_list = all_ports[-5:]
all_ports = all_ports[:-5]
new_stateless_world_group_port_list = [
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
]
start_idx = num_world_groups * 3
new_stateless_dp_group_port_list = [
all_ports[i : i + 3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
]
start_idx += num_dp_groups * 3
new_stateless_ep_group_port_list = [
all_ports[i : i + 3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
]
start_idx += num_ep_groups * 3
new_stateless_eplb_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
]
parallel_config._stateless_world_group_port_list = (
new_stateless_world_group_port_list
)
parallel_config._stateless_dp_group_port_list = new_stateless_dp_group_port_list
parallel_config._stateless_ep_group_port_list = new_stateless_ep_group_port_list
parallel_config._stateless_eplb_group_port_list = new_stateless_eplb_group_port_list
parallel_config.data_parallel_master_port = new_data_parallel_master_port_list.pop()
parallel_config._data_parallel_master_port_list = new_data_parallel_master_port_list
class MPClient(EngineCoreClient):
"""
MPClient: base client for multi-proc EngineCore.
......@@ -491,32 +550,37 @@ class MPClient(EngineCoreClient):
input_address = client_addresses["input_address"]
output_address = client_addresses["output_address"]
self.stats_update_address = client_addresses.get("stats_update_address")
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, input_address, zmq.ROUTER, bind=True
)
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.PULL
)
else:
# Engines are managed by this client.
with launch_core_engines(vllm_config, executor_class, log_stats) as (
engine_manager,
coordinator,
addresses = get_engine_zmq_addresses(vllm_config)
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True
)
self.resources.output_socket = make_zmq_socket(
self.ctx, addresses.outputs[0], zmq.PULL
)
with launch_core_engines(
vllm_config,
executor_class,
log_stats,
addresses,
):
) as (engine_manager, coordinator, addresses):
self.resources.coordinator = coordinator
self.resources.engine_manager = engine_manager
(input_address,) = addresses.inputs
(output_address,) = addresses.outputs
self.stats_update_address = addresses.frontend_stats_publish_address
if coordinator is not None:
assert self.stats_update_address == (
coordinator.get_stats_publish_address()
)
# Create input and output sockets.
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, input_address, zmq.ROUTER, bind=True
)
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.PULL
)
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_index
......@@ -877,6 +941,10 @@ class AsyncMPClient(MPClient):
output_socket = resources.output_socket
assert output_socket is not None
notification_callback_handler: (
Callable[[AsyncMPClient, Sequence[Any]], Any] | None
) = getattr(self.__class__, "eep_process_engine_core_notification", None)
async def process_outputs_socket():
try:
while True:
......@@ -884,7 +952,26 @@ class AsyncMPClient(MPClient):
resources.validate_alive(frames)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output, utility_results)
if (
outputs.utility_output.call_id == EEP_NOTIFICATION_CALL_ID
and notification_callback_handler is not None
):
assert _self_ref is not None
_self = _self_ref()
if not _self:
return
if outputs.utility_output.result is None:
continue
notification_data = outputs.utility_output.result.result
assert isinstance(notification_data, Sequence)
assert len(notification_data) == 2
asyncio.create_task(
notification_callback_handler(_self, notification_data)
)
else:
_process_utility_output(
outputs.utility_output, utility_results
)
continue
if output_handler is not None:
......@@ -1081,6 +1168,8 @@ class DPAsyncMPClient(AsyncMPClient):
# Used only by DPLBAsyncMPClient subclass.
self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines]
self.eep_scaling_cache: ElasticScalingCache | None = None
self.first_req_sock_addr = get_open_zmq_inproc_path()
self.first_req_send_socket = self.resources.first_req_send_socket = (
make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True)
......@@ -1101,12 +1190,6 @@ class DPAsyncMPClient(AsyncMPClient):
assert self.stats_update_address is not None
stats_addr: str = self.stats_update_address
assert len(self.engine_ranks_managed) > 0
# NOTE: running and waiting counts are all global from
# the Coordinator include all global EngineCores. This
# slice includes just the cores managed by this client.
count_slice = slice(
self.engine_ranks_managed[0], self.engine_ranks_managed[-1] + 1
)
async def run_engine_stats_update_task():
with (
......@@ -1145,6 +1228,29 @@ class DPAsyncMPClient(AsyncMPClient):
):
# Extract new engine count from the decoded message
new_engine_count = decoded[1]
# Update engine_ranks_managed and count_slice
parallel_config = self.vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
assert dp_rank == 0
assert dp_size == new_engine_count
assert not (
parallel_config.data_parallel_hybrid_lb
or parallel_config.data_parallel_external_lb
)
num_ranks = dp_size
self.engine_ranks_managed = list(
range(dp_rank, dp_rank + num_ranks)
)
if len(self.lb_engines) < new_engine_count:
self.lb_engines = self.lb_engines + [
[0, 0]
for _ in range(
new_engine_count - len(self.lb_engines)
)
]
else:
self.lb_engines = self.lb_engines[:new_engine_count]
# Send scale up notification to coordinator
scale_msg = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_engine_count)
......@@ -1178,6 +1284,11 @@ class DPAsyncMPClient(AsyncMPClient):
self.current_wave = wave
self.engines_running = running
if counts is not None:
# Running and waiting counts are global from the
# Coordinator including all EngineCores. Slice to get
# just the cores managed by this client.
ranks = self.engine_ranks_managed
count_slice = slice(ranks[0], ranks[-1] + 1)
sliced_counts = counts[count_slice]
self.lb_engines = sliced_counts
logger.debug(
......@@ -1287,6 +1398,67 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
for req_id in outputs.finished_requests:
self.reqs_in_flight.pop(req_id, None)
@staticmethod
async def eep_process_engine_core_notification(
self: "DPLBAsyncMPClient", notification_data: tuple[str, int]
):
cache = self.eep_scaling_cache
notification_type_str, dp_rank = notification_data
try:
notification_type = EEPNotificationType(notification_type_str)
except ValueError as e:
raise ValueError(
f"Unknown EEP notification type: {notification_type_str}"
) from e
if notification_type == EEPNotificationType.RECONFIGURE_FINISHED:
from vllm.v1.engine import UtilityResult
# NOTE(yongji): process a dummy UtilityOutput to resolve the future
# awaited in _eep_wait_for_setup_switch_complete(), signaling that
# all engine cores have completed reconfiguration.
dummy_output = UtilityOutput(
call_id=EEP_NOTIFICATION_CALL_ID, result=UtilityResult(None)
)
_process_utility_output(dummy_output, self.utility_results)
return
assert cache is not None
if notification_type not in cache.pending_notifications:
cache.pending_notifications[notification_type] = set()
if dp_rank in cache.pending_notifications[notification_type]:
raise ValueError(
f"Duplicate notification {notification_type} from dp_rank {dp_rank}"
)
cache.pending_notifications[notification_type].add(dp_rank)
if len(cache.pending_notifications[notification_type]) >= abs(
cache.num_new_core_engines
):
if notification_type == EEPNotificationType.SHUTDOWN_COMPLETE:
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
assert cache.num_new_core_engines < 0
old_dp_size = len(cache.existing_core_engines)
new_dp_size = old_dp_size + cache.num_new_core_engines
self.resources.engine_manager.scale_down_elastic_ep(
old_dp_size, new_dp_size
)
else:
await asyncio.gather(
*[
self._call_utility_async(
"eep_handle_engine_core_notification",
notification_type,
engine=engine,
)
for engine in cache.existing_core_engines
]
)
cache.pending_notifications[notification_type] = set()
if notification_type in [
EEPNotificationType.SHUTDOWN_COMPLETE,
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY,
]:
self.eep_scaling_cache = None
async def abort_requests_async(self, request_ids: list[str]) -> None:
if not request_ids or self.resources.engine_dead:
return
......@@ -1333,6 +1505,20 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
cur_data_parallel_size, new_data_parallel_size
)
async def _eep_wait_for_setup_switch_complete(self) -> None:
"""
Wait for core engines to switch to the new setup.
In eep_process_engine_core_notification(), a dummy UtilityOutput with
EEP_NOTIFICATION_CALL_ID will be set when RECONFIGURE_FINISHED
notification is received from engine 0. We create a future with
that call_id and wait for it to be resolved.
"""
future = asyncio.get_running_loop().create_future()
self.utility_results[EEP_NOTIFICATION_CALL_ID] = future
self._ensure_output_queue_task()
await future
async def _scale_up_elastic_ep(
self, cur_data_parallel_size: int, new_data_parallel_size: int
) -> None:
......@@ -1340,38 +1526,57 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
and reconfiguring existing ones."""
cur_data_parallel_size = len(self.core_engines)
# Phase 1: Send reconfigure messages to all existing engines and wait
# for them to be sent
self.eep_scaling_cache = ElasticScalingCache(
existing_core_engines=self.core_engines.copy(),
num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
pending_notifications=dict(),
)
parallel_config = self.vllm_config.parallel_config
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
# Phase 1: Send reconfig messages to existing engines
reconfig_futures = []
self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
for engine in self.core_engines:
reconfig_request = ReconfigureDistributedRequest(
new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port,
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
)
coro = self._call_utility_async(
"reinitialize_distributed", reconfig_request, engine=engine
)
reconfig_futures.append(asyncio.create_task(coro))
logger.info("All reconfigure messages sent, starting engine creation")
# Phase 2: Create new engines now that reconfig messages have been sent
# self.resources.engine_manager is guaranteed to be
# CoreEngineActorManager for RayDPClient
# Phase 2: Create new engines
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
self.resources.engine_manager.scale_up_elastic_ep(
self.vllm_config, new_data_parallel_size
parallel_config.eplb_config.num_redundant_experts = 0
start_new_worker_future = asyncio.to_thread(
self.resources.engine_manager.scale_up_elastic_ep,
self.vllm_config,
new_data_parallel_size,
)
wait_future = self._eep_wait_for_setup_switch_complete()
# Phase 3: Wait for new engines to be created
# and reconfig messages to be received
await asyncio.gather(start_new_worker_future, *reconfig_futures)
logger.info("[Elastic EP] Successfully started new engines")
# Create new CoreEngine objects for the new engines
new_engine_identities = set()
for i in range(cur_data_parallel_size, new_data_parallel_size):
new_engine = i.to_bytes(2, "little")
self.core_engines.append(new_engine)
# NOTE(yongji): we don't update lb_engines here,
# we let run_engine_stats_update_task to update it.
new_engine_identities.add(new_engine)
# Wait for ready messages from new engines on the input socket
......@@ -1387,10 +1592,11 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
identity, _ = sync_input_socket.recv_multipart()
new_engine_identities.discard(identity)
# Phase 3: Wait for all existing engines to complete reconfiguration
logger.info("Waiting for existing engines to complete reconfiguration")
await asyncio.gather(*reconfig_futures)
# NOTE(yongji): Before we schedule any requests on the new workers,
# we should wait for them to switch to the new setup.
await wait_future
# Update the parallel config
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# Notify coordinator about scale up through existing
# stats_update_task connection
self._ensure_stats_update_task()
......@@ -1399,8 +1605,6 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
await self.first_req_send_socket.send(scale_up_marker)
# Update the parallel config
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
logger.info(
"[Elastic EP] Scale up completed, new data parallel size: %s",
new_data_parallel_size,
......@@ -1413,7 +1617,14 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
reconfiguring existing engine cores."""
cur_data_parallel_size = len(self.core_engines)
self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
self.eep_scaling_cache = ElasticScalingCache(
existing_core_engines=self.core_engines.copy(),
num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
pending_notifications=dict(),
)
parallel_config = self.vllm_config.parallel_config
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
reconfig_futures = []
for cur_dp_rank, engine in enumerate(self.core_engines):
......@@ -1421,8 +1632,13 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port,
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
)
if cur_dp_rank >= new_data_parallel_size:
reconfig_request.new_data_parallel_rank = (
......@@ -1433,23 +1649,24 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
reconfig_futures.append(asyncio.create_task(coro))
for _ in range(new_data_parallel_size, cur_data_parallel_size):
self.core_engines.pop()
# NOTE(yongji): Immediately stop sending requests to the removing engines.
self.core_engines = self.core_engines[:new_data_parallel_size]
self.lb_engines = self.lb_engines[:new_data_parallel_size]
wait_future = self._eep_wait_for_setup_switch_complete()
await asyncio.gather(*reconfig_futures)
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
self.resources.engine_manager.scale_down_elastic_ep(
cur_data_parallel_size, new_data_parallel_size
)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
self._ensure_stats_update_task()
scale_down_marker = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_data_parallel_size)
)
await self.first_req_send_socket.send(scale_down_marker)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# NOTE(yongji): Unlike scaling up,
# here we don't actually need to wait for the setup switch to complete.
# We may want to remove it in the future.
await wait_future
logger.info(
"[Elastic EP] Scale down completed, new data parallel size: %s",
new_data_parallel_size,
......
......@@ -277,6 +277,8 @@ class CoreEngineActorManager:
else:
ray.init()
vllm_config.parallel_config.allocate_elastic_ep_ports()
if placement_groups is not None:
assert local_dp_ranks is not None, (
"local_dp_ranks must be provided if placement_groups is provided"
......@@ -584,6 +586,8 @@ class CoreEngineActorManager:
node_ip = node.node_ip
node_id = node.node_id
if device_str not in available_resources[node_id]:
continue
available_gpus = int(available_resources[node_id][device_str])
# Get total GPUs on this node from the node's resources
......@@ -773,26 +777,15 @@ class CoreEngineActorManager:
ray.util.remove_placement_group(pg)
@contextlib.contextmanager
def launch_core_engines(
def get_engine_zmq_addresses(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
num_api_servers: int = 1,
) -> Iterator[
tuple[
CoreEngineProcManager | CoreEngineActorManager | None,
DPCoordinator | None,
EngineZmqAddresses,
]
]:
"""Launch engine and DP coordinator processes as needed."""
) -> EngineZmqAddresses:
"""Allocate ZMQ addresses for engine-client communication."""
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank
dp_size = parallel_config.data_parallel_size
host = parallel_config.data_parallel_master_ip
local_engines_only = parallel_config.local_engines_only
......@@ -806,9 +799,11 @@ def launch_core_engines(
client_local_only = (
offline_mode or local_engines_only or (local_engine_count == dp_size)
)
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
client_local_only = False
# Set up input and output addresses.
addresses = EngineZmqAddresses(
return EngineZmqAddresses(
inputs=[
get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers)
......@@ -819,6 +814,33 @@ def launch_core_engines(
],
)
@contextlib.contextmanager
def launch_core_engines(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
addresses: EngineZmqAddresses,
num_api_servers: int = 1,
) -> Iterator[
tuple[
CoreEngineProcManager | CoreEngineActorManager | None,
DPCoordinator | None,
EngineZmqAddresses,
]
]:
"""Launch engine and DP coordinator processes as needed."""
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank
host = parallel_config.data_parallel_master_ip
local_engines_only = parallel_config.local_engines_only
offline_mode = local_start_index is not None
# Run the DP Coordinator process with rank 0 when in online DP mode.
# The coordinator is needed for:
# 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
......@@ -885,6 +907,10 @@ def launch_core_engines(
# will be False.
handshake_local_only = offline_mode or local_engine_count == dp_size
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
handshake_local_only = False
handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port
)
......
......@@ -38,6 +38,7 @@ from vllm.distributed.parallel_state import (
get_pcp_group,
get_pp_group,
get_tp_group,
model_parallel_is_initialized,
)
from vllm.envs import enable_envs_cache
from vllm.logger import init_logger
......@@ -580,17 +581,20 @@ class WorkerProc:
)
self.async_output_copy_thread.start()
# Initialize device
self.worker.init_device()
# Set process title and log prefix
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
# Load model
self._init_message_queues(input_shm_handle, vllm_config)
self.worker.load_model()
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.worker.init_device()
# Update process title now that parallel groups are initialized
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
self.worker.load_model()
# Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point)
......@@ -885,6 +889,13 @@ class WorkerProc:
@staticmethod
def setup_proc_title_and_log_prefix(enable_ep: bool) -> None:
# Check if parallel groups are initialized first
if not model_parallel_is_initialized():
# Parallel groups not yet initialized, use default process name
set_process_title(name="Worker")
decorate_logs("Worker")
return
dp_size = get_dp_group().world_size
dp_rank = get_dp_group().rank_in_group
pp_size = get_pp_group().world_size
......
......@@ -382,8 +382,10 @@ class RayDistributedExecutor(Executor):
all_kwargs.append(kwargs)
self.collective_rpc("init_worker", args=(all_kwargs,))
self.collective_rpc("init_device")
self.collective_rpc("load_model")
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.collective_rpc("init_device")
self.collective_rpc("load_model")
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
......
......@@ -14,7 +14,6 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.serial_utils import run_method
......@@ -43,9 +42,11 @@ class UniProcExecutor(Executor):
max_workers=1, thread_name_prefix="WorkerAsyncOutput"
)
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
self.driver_worker.init_worker(all_kwargs=[kwargs])
self.driver_worker.init_device()
self.driver_worker.load_model()
if not is_eep_new_worker:
self.driver_worker.init_device()
self.driver_worker.load_model()
def _distributed_args(self) -> tuple[str, int, int]:
"""Return (distributed_init_method, rank, local_rank)."""
......@@ -122,16 +123,6 @@ class UniProcExecutor(Executor):
# it's running.
return
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
self.driver_worker.reinitialize_distributed(reconfig_request)
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
self.shutdown()
def shutdown(self) -> None:
if worker := self.driver_worker:
worker.shutdown()
......
......@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner):
v.gpu = v.cpu
@instrument(span_name="Loading (CPU)")
def load_model(self, eep_scale_up: bool = False) -> None:
def load_model(self, load_dummy_weights: bool = False) -> None:
if load_dummy_weights:
raise ValueError(
"Loading dummy weights (needed for elastic EP scale-up) "
"Is not supported by the CPU Model Runner."
)
logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config)
......
......@@ -461,6 +461,8 @@ class GPUModelRunner(
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
self.eplb_state: EplbState | None = None
# NOTE(yongji): flag to temporarily disable EPLB during scaling up/down
self.eep_eplb_suppressed = False
"""
State of the expert parallelism load balancer.
......@@ -2702,7 +2704,7 @@ class GPUModelRunner(
"""
Step for the EPLB (Expert Parallelism Load Balancing) state.
"""
if not self.parallel_config.enable_eplb:
if not self.parallel_config.enable_eplb or self.eep_eplb_suppressed:
return
assert self.eplb_state is not None
......@@ -2714,6 +2716,23 @@ class GPUModelRunner(
log_stats=self.parallel_config.eplb_config.log_balancedness,
)
def setup_eplb_from_mapping(
self,
expanded_physical_to_logical: torch.Tensor,
old_num_physical_experts: int,
) -> None:
model = self.get_model()
assert is_mixture_of_experts(model)
self.eplb_state = EplbState.from_mapping(
model=model,
model_config=self.model_config,
device=self.device,
parallel_config=self.parallel_config,
expanded_physical_to_logical=expanded_physical_to_logical,
num_valid_physical_experts=old_num_physical_experts,
)
def _pool(
self,
hidden_states: torch.Tensor,
......@@ -4175,21 +4194,16 @@ class GPUModelRunner(
setattr(self, config_name, new_config)
@instrument(span_name="Loading (GPU)")
def load_model(self, eep_scale_up: bool = False) -> None:
def load_model(self, load_dummy_weights: bool = False) -> None:
"""
Args:
eep_scale_up: the model loading is for elastic EP scale up.
load_dummy_weights: load dummy weights instead of real weights.
"""
logger.info_once(
"Starting to load model %s...",
self.model_config.model,
scope="global",
)
global_expert_loads, old_global_expert_indices_per_model, rank_mapping = (
EplbState.get_eep_state(self.parallel_config)
if eep_scale_up
else (None, None, None)
)
if self.parallel_config.enable_eplb:
self.eplb_state = EplbState(self.parallel_config, self.device)
......@@ -4198,6 +4212,8 @@ class GPUModelRunner(
try:
with DeviceMemoryProfiler() as m:
time_before_load = time.perf_counter()
if load_dummy_weights:
self.load_config.load_format = "dummy"
model_loader = get_model_loader(self.load_config)
self.model = model_loader.load_model(
vllm_config=self.vllm_config, model_config=self.model_config
......@@ -4214,6 +4230,9 @@ class GPUModelRunner(
and is_mixture_of_experts(self.drafter.model)
and self.parallel_config.enable_eplb
):
assert not self.parallel_config.enable_elastic_ep, (
"Elastic EP is not supported with drafter model."
)
spec_config = self.vllm_config.speculative_config
assert spec_config is not None
assert spec_config.draft_model_config is not None
......@@ -4221,17 +4240,6 @@ class GPUModelRunner(
"EPLB is enabled for drafter model %s.",
spec_config.draft_model_config.model,
)
global_expert_load = (
global_expert_loads[eplb_models]
if global_expert_loads
else None
)
old_global_expert_indices = (
old_global_expert_indices_per_model[eplb_models]
if old_global_expert_indices_per_model
else None
)
if self.eplb_state is None:
self.eplb_state = EplbState(
self.parallel_config, self.device
......@@ -4239,9 +4247,6 @@ class GPUModelRunner(
self.eplb_state.add_model(
self.drafter.model,
spec_config.draft_model_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
)
eplb_models += 1
......@@ -4283,11 +4288,12 @@ class GPUModelRunner(
time_after_load - time_before_load,
scope="local",
)
prepare_communication_buffer_for_model(self.model)
if (drafter := getattr(self, "drafter", None)) and (
drafter_model := getattr(drafter, "model", None)
):
prepare_communication_buffer_for_model(drafter_model)
if not load_dummy_weights:
prepare_communication_buffer_for_model(self.model)
if (drafter := getattr(self, "drafter", None)) and (
drafter_model := getattr(drafter, "model", None)
):
prepare_communication_buffer_for_model(drafter_model)
mm_config = self.model_config.multimodal_config
self.is_multimodal_pruning_enabled = (
supports_multimodal_pruning(self.get_model())
......@@ -4295,26 +4301,19 @@ class GPUModelRunner(
and mm_config.is_multimodal_pruning_enabled()
)
if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
if (
is_mixture_of_experts(self.model)
and self.parallel_config.enable_eplb
and not load_dummy_weights
):
logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
global_expert_load = (
global_expert_loads[eplb_models] if global_expert_loads else None
)
old_global_expert_indices = (
old_global_expert_indices_per_model[eplb_models]
if old_global_expert_indices_per_model
else None
)
assert self.eplb_state is not None
self.eplb_state.add_model(
self.model,
self.model_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
)
if self.eplb_state.is_async:
self.eplb_state.start_async_loop(rank_mapping=rank_mapping)
self.eplb_state.start_async_loop()
if (
self.vllm_config.compilation_config.mode
......
......@@ -7,11 +7,10 @@ import os
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
......@@ -32,14 +31,12 @@ from vllm.distributed.kv_transfer import (
)
from vllm.distributed.parallel_state import (
Handle,
get_pcp_group,
get_pp_group,
get_tp_group,
)
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
......@@ -49,7 +46,6 @@ from vllm.tracing import instrument
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
......@@ -124,6 +120,10 @@ class Worker(WorkerBase):
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
torch.set_float32_matmul_precision(precision)
from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor
self.elastic_ep_executor = ElasticEPScalingExecutor(self)
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
......@@ -317,12 +317,29 @@ class Worker(WorkerBase):
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def load_model(self) -> None:
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
if dummy_weights:
(
expanded_physical_to_logical,
num_logical_experts,
old_num_physical_experts,
) = self.elastic_ep_executor.receive_expert_mapping()
num_physical_experts = expanded_physical_to_logical.shape[1]
self.parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)
with (
self._maybe_get_memory_pool_context(tag="weights"),
set_current_vllm_config(self.vllm_config),
):
self.model_runner.load_model(eep_scale_up=eep_scale_up)
self.model_runner.load_model(load_dummy_weights=dummy_weights)
if dummy_weights:
self.model_runner.setup_eplb_from_mapping(
expanded_physical_to_logical, old_num_physical_experts
)
self.model_runner.eep_eplb_suppressed = True
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
......@@ -801,227 +818,6 @@ class Worker(WorkerBase):
# worker will always be healthy as long as it's running.
return
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info(
"[Elastic EP] Starting expert resharding before scaling down..."
)
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_loads=None,
rank_mapping=rank_mapping,
)
torch.cuda.synchronize()
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def _eplb_after_scale_up(
self,
old_ep_size: int,
new_ep_size: int,
global_expert_loads: list[torch.Tensor] | None,
) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Starting expert resharding after scaling up...")
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_loads=global_expert_loads,
rank_mapping=rank_mapping,
)
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def _reconfigure_parallel_config(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
"""
Update parallel config with provided reconfig_request
"""
parallel_config = self.vllm_config.parallel_config
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
def _reconfigure_moe(
self, old_ep_size: int, new_ep_size: int
) -> list[torch.Tensor] | None:
"""
Reconfigure MoE modules with provided reconfig_request
Return the global expert load if new_ep_size > old_ep_size,
otherwise None
"""
from vllm.distributed.parallel_state import (
get_dp_group,
get_ep_group,
prepare_communication_buffer_for_model,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEParallelConfig,
)
parallel_config = self.vllm_config.parallel_config
def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
return [
module
for module in model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
assert all(
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules
), "All MoE modules must have the same number of experts"
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
return moe_modules
model_moe_modules = get_moe_modules(self.model_runner.model)
num_local_experts = model_moe_modules[0].moe_config.num_local_experts
update_moe_modules(model_moe_modules, num_local_experts)
drafter_model = None
if hasattr(self.model_runner, "drafter") and hasattr(
self.model_runner.drafter, "model"
):
drafter_model = self.model_runner.drafter.model
if drafter_model is not None and is_mixture_of_experts(drafter_model):
drafter_moe_modules = get_moe_modules(drafter_model)
# Check if drafter and model have matching configs
assert (
drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
), "Drafter and model configs should be the same"
update_moe_modules(drafter_moe_modules, num_local_experts)
if new_ep_size < old_ep_size:
num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None
new_physical_experts = (
self.model_runner.eplb_state.physical_to_logical_map.shape[1] # type: ignore[attr-defined]
)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts
- self.model_runner.eplb_state.logical_replica_count.shape[1] # type: ignore[attr-defined]
)
global_expert_loads = None
else:
num_local_physical_experts_tensor = torch.tensor(
[num_local_experts], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
num_local_physical_experts_tensor,
group=get_ep_group().cpu_group,
group_src=0,
)
num_local_physical_experts = int(num_local_physical_experts_tensor.item())
new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None
global_expert_loads_any = self.model_runner.eplb_state.rearrange(
execute_shuffle=False
)
global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_loads[0].shape[1]
)
prepare_communication_buffer_for_model(self.model_runner.model)
if drafter_model is not None:
prepare_communication_buffer_for_model(drafter_model)
self.model_runner.model.update_physical_experts_metadata(
num_physical_experts=new_physical_experts,
num_local_physical_experts=num_local_physical_experts,
)
return global_expert_loads
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
cleanup_dist_env_and_memory,
get_ep_group,
)
old_ep_size = get_ep_group().world_size
old_ep_rank = get_ep_group().rank
new_ep_size = (
reconfig_request.new_data_parallel_size
* get_tp_group().world_size
* get_pp_group().world_size
)
if new_ep_size < old_ep_size:
self._eplb_before_scale_down(old_ep_size, new_ep_size)
cleanup_dist_env_and_memory()
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
assert old_ep_rank >= new_ep_size
# shutdown
return
self._reconfigure_parallel_config(reconfig_request)
with set_current_vllm_config(self.vllm_config):
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
)
global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
if new_ep_size > old_ep_size:
assert global_expert_loads is not None
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
def save_sharded_state(
self,
path: str,
......@@ -1118,6 +914,9 @@ class Worker(WorkerBase):
if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
weight_transfer_engine.shutdown()
def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
return self.elastic_ep_executor.execute(execute_method, *args, **kwargs)
def init_worker_distributed_environment(
vllm_config: VllmConfig,
......
......@@ -66,6 +66,23 @@ class WorkspaceManager:
],
)
def unlock(self) -> None:
"""Unlock the workspace to allow growth.
This is used during elastic EP scaling when the workspace size
needs to grow due to changes in the number of experts.
"""
self._locked = False
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Workspace unlocked. Current sizes: %s",
[
self._workspace_size_bytes(ws) / _MB
for ws in self._current_workspaces
if ws is not None
],
)
def is_locked(self) -> bool:
"""Check if workspace is locked."""
return self._locked
......@@ -242,6 +259,17 @@ def lock_workspace() -> None:
current_workspace_manager().lock()
def unlock_workspace() -> None:
"""Unlock the workspace to allow growth.
This is used during elastic EP scaling when the workspace size
needs to grow due to changes in the number of experts.
After scaling operations complete, lock_workspace() should be
called again to prevent unexpected allocations.
"""
current_workspace_manager().unlock()
def reset_workspace_manager() -> None:
"""Reset the workspace manager to uninitialized state.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment