Unverified Commit bd877162 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Support online dense model DP without overhead (#30739)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
Signed-off-by: default avatarnjhill <nickhill123@gmail.com>
parent 08f425ba
...@@ -205,8 +205,8 @@ def test_default_pooling_type(model_id, default_pooling_type, pooling_type): ...@@ -205,8 +205,8 @@ def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
) )
def test_moe_model_detection(model_id, expected_is_moe_model): def test_moe_model_detection(model_id, expected_is_moe_model):
model_config = ModelConfig(model_id) model_config = ModelConfig(model_id)
# Just check that is_moe_model field exists and is a boolean # Just check that is_moe field exists and is a boolean
assert model_config.is_model_moe() == expected_is_moe_model assert model_config.is_moe == expected_is_moe_model
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -224,7 +224,7 @@ def test_moe_model_detection(model_id, expected_is_moe_model): ...@@ -224,7 +224,7 @@ def test_moe_model_detection(model_id, expected_is_moe_model):
def test_is_quantized(model_id, quantized): def test_is_quantized(model_id, quantized):
model_config = ModelConfig(model_id) model_config = ModelConfig(model_id)
# Just check that quantized field exists and is a boolean # Just check that quantized field exists and is a boolean
assert model_config.is_quantized() == quantized assert model_config.is_quantized == quantized
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -925,7 +925,7 @@ def test_vllm_config_callable_defaults(): ...@@ -925,7 +925,7 @@ def test_vllm_config_callable_defaults():
model_config=quantized_model, optimization_level=OptimizationLevel.O2 model_config=quantized_model, optimization_level=OptimizationLevel.O2
) )
enable_if_quantized = lambda cfg: ( enable_if_quantized = lambda cfg: (
cfg.model_config is not None and cfg.model_config.is_quantized() cfg.model_config is not None and cfg.model_config.is_quantized
) )
assert enable_if_quantized(config_quantized) is True assert enable_if_quantized(config_quantized) is True
assert enable_if_quantized(config_no_model) is False assert enable_if_quantized(config_no_model) is False
...@@ -936,7 +936,7 @@ def test_vllm_config_callable_defaults(): ...@@ -936,7 +936,7 @@ def test_vllm_config_callable_defaults():
model_config=moe_model, optimization_level=OptimizationLevel.O2 model_config=moe_model, optimization_level=OptimizationLevel.O2
) )
enable_if_sequential = lambda cfg: ( enable_if_sequential = lambda cfg: (
cfg.model_config is not None and not cfg.model_config.is_model_moe() cfg.model_config is not None and not cfg.model_config.is_moe
) )
assert enable_if_sequential(config_moe) is False assert enable_if_sequential(config_moe) is False
assert enable_if_sequential(config_quantized) is True assert enable_if_sequential(config_quantized) is True
...@@ -1050,3 +1050,46 @@ def test_scheduler_config_init(): ...@@ -1050,3 +1050,46 @@ def test_scheduler_config_init():
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
# InitVar does not become an attribute # InitVar does not become an attribute
print(SchedulerConfig.default_factory().max_model_len) print(SchedulerConfig.default_factory().max_model_len)
@pytest.mark.parametrize(
(
"model_id",
"data_parallel_size",
"external_lb",
"expected_needs_coordinator",
),
[
# Non-MoE model with DP=1 should not need coordinator
("facebook/opt-125m", 1, False, False),
# Non-MoE model with DP>1 internal LB should need coordinator
("facebook/opt-125m", 2, False, True),
# Non-MoE model with DP>1 external LB should not need coordinator
("facebook/opt-125m", 2, True, False),
# MoE model with DP=1 should not need coordinator
("mistralai/Mixtral-8x7B-Instruct-v0.1", 1, False, False),
# MoE model with DP>1 internal LB should need both coordinator
# and wave coordination
("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, False, True),
# MoE model with DP>1 external LB needs coordinator for wave coordination
# (wave coordination runs in coordinator process)
("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, True, True),
],
)
def test_needs_dp_coordination(
model_id,
data_parallel_size,
external_lb,
expected_needs_coordinator,
):
"""Test that DP coordinator and wave coordination are configured correctly."""
from vllm.config import ParallelConfig
model_config = ModelConfig(model_id)
parallel_config = ParallelConfig(
data_parallel_size=data_parallel_size,
data_parallel_external_lb=external_lb,
)
vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)
assert vllm_config.needs_dp_coordinator == expected_needs_coordinator
...@@ -133,6 +133,7 @@ def test_mp_client_uses_env_timeout(monkeypatch: pytest.MonkeyPatch): ...@@ -133,6 +133,7 @@ def test_mp_client_uses_env_timeout(monkeypatch: pytest.MonkeyPatch):
parallel_config = SimpleNamespace( parallel_config = SimpleNamespace(
data_parallel_size=1, data_parallel_size=1,
data_parallel_rank=0, data_parallel_rank=0,
data_parallel_index=0,
data_parallel_size_local=1, data_parallel_size_local=1,
data_parallel_rank_local=None, data_parallel_rank_local=None,
data_parallel_hybrid_lb=False, data_parallel_hybrid_lb=False,
......
...@@ -630,7 +630,7 @@ class VllmBackend: ...@@ -630,7 +630,7 @@ class VllmBackend:
os.makedirs(cache_dir, exist_ok=True) os.makedirs(cache_dir, exist_ok=True)
self.compilation_config.cache_dir = cache_dir self.compilation_config.cache_dir = cache_dir
rank = vllm_config.parallel_config.rank rank = vllm_config.parallel_config.rank
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_rank = vllm_config.parallel_config.data_parallel_index
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix) local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
os.makedirs(local_cache_dir, exist_ok=True) os.makedirs(local_cache_dir, exist_ok=True)
self.compilation_config.local_cache_dir = local_cache_dir self.compilation_config.local_cache_dir = local_cache_dir
......
...@@ -403,7 +403,7 @@ def _support_torch_compile( ...@@ -403,7 +403,7 @@ def _support_torch_compile(
) )
rank = self.vllm_config.parallel_config.rank rank = self.vllm_config.parallel_config.rank
dp_rank = self.vllm_config.parallel_config.data_parallel_rank dp_rank = self.vllm_config.parallel_config.data_parallel_index
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
aot_compilation_path = os.path.join(cache_dir, "model") aot_compilation_path = os.path.join(cache_dir, "model")
try: try:
......
...@@ -642,7 +642,7 @@ class ModelConfig: ...@@ -642,7 +642,7 @@ class ModelConfig:
cls = "Transformers" cls = "Transformers"
# If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal # If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal
cls += "MultiModal" if self.hf_config != self.hf_text_config else "" cls += "MultiModal" if self.hf_config != self.hf_text_config else ""
cls += "MoE" if self.get_num_experts() > 1 else "" cls += "MoE" if self.is_moe else ""
# Check if the architecture we're wrapping has defaults # Check if the architecture we're wrapping has defaults
runner = None runner = None
task = None task = None
...@@ -1001,8 +1001,7 @@ class ModelConfig: ...@@ -1001,8 +1001,7 @@ class ModelConfig:
self.enforce_eager = True self.enforce_eager = True
def _verify_with_expert_parallelism(self) -> None: def _verify_with_expert_parallelism(self) -> None:
num_experts = self.get_num_experts() if not self.is_moe:
if num_experts < 1:
raise ValueError( raise ValueError(
"Number of experts in the model must be greater than 0 " "Number of experts in the model must be greater than 0 "
"when expert parallelism is enabled." "when expert parallelism is enabled."
...@@ -1797,11 +1796,11 @@ class ModelConfig: ...@@ -1797,11 +1796,11 @@ class ModelConfig:
logger.debug("Generative models support prefix caching.") logger.debug("Generative models support prefix caching.")
return True return True
def is_model_moe( @property
self, def is_moe(self) -> bool:
) -> bool: return self.get_num_experts() > 0
return self.get_num_experts() > 1
@property
def is_quantized(self) -> bool: def is_quantized(self) -> bool:
return getattr(self.hf_config, "quantization_config", None) is not None return getattr(self.hf_config, "quantization_config", None) is not None
......
...@@ -119,6 +119,8 @@ class ParallelConfig: ...@@ -119,6 +119,8 @@ class ParallelConfig:
between local data parallel ranks, but an external LB balances between local data parallel ranks, but an external LB balances
between vLLM nodes/replicas. Set explicitly in conjunction with between vLLM nodes/replicas. Set explicitly in conjunction with
--data-parallel-start-rank.""" --data-parallel-start-rank."""
is_moe_model: bool | None = None
"""Whether the deployed model is MoE (if known)."""
enable_expert_parallel: bool = False enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers.""" """Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False enable_eplb: bool = False
...@@ -255,6 +257,10 @@ class ParallelConfig: ...@@ -255,6 +257,10 @@ class ParallelConfig:
Block_size should be divisible by cp_kv_cache_interleave_size. Block_size should be divisible by cp_kv_cache_interleave_size.
""" """
data_parallel_index: int = Field(init=False)
"""Equal to the data parallel rank but not used for torch process groups
and not overridden for dense models."""
_api_process_count: int = Field(default=1, gt=0) _api_process_count: int = Field(default=1, gt=0)
""" """
The number of API processes initialized. The number of API processes initialized.
...@@ -466,6 +472,7 @@ class ParallelConfig: ...@@ -466,6 +472,7 @@ class ParallelConfig:
"data_parallel_rank", "data_parallel_rank",
"data_parallel_rank_local", "data_parallel_rank_local",
"data_parallel_size_local", "data_parallel_size_local",
"data_parallel_index",
"data_parallel_backend", "data_parallel_backend",
"data_parallel_external_lb", "data_parallel_external_lb",
"data_parallel_hybrid_lb", "data_parallel_hybrid_lb",
...@@ -546,6 +553,14 @@ class ParallelConfig: ...@@ -546,6 +553,14 @@ class ParallelConfig:
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
if self.data_parallel_size > 1 and self.is_moe_model is False:
raise ValueError(
"Offline data parallel mode is not supported/useful"
" for dense models."
)
self.data_parallel_index = self.data_parallel_rank
if self.distributed_executor_backend == "external_launcher": if self.distributed_executor_backend == "external_launcher":
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.") logger.info("Disabling V1 multiprocessing for external launcher.")
......
...@@ -343,6 +343,29 @@ class VllmConfig: ...@@ -343,6 +343,29 @@ class VllmConfig:
# i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size # i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
return self.compilation_config.bs_to_padded_graph_size[batch_size] return self.compilation_config.bs_to_padded_graph_size[batch_size]
@property
def needs_dp_coordinator(self) -> bool:
"""
Determine if the DPCoordinator process is needed.
The DPCoordinator is needed in two cases:
1. For MoE models with DP > 1: to handle wave coordination
(even in external LB mode, since wave coordination runs in the coordinator)
2. For non-MoE models in internal/hybrid LB mode: to collect and publish
queue stats for load balancing across DP ranks
Returns:
True if DPCoordinator process is needed, False otherwise.
"""
# For non-MoE models, only need coordinator in internal/hybrid LB mode
# (for stats collection).
return self.parallel_config.data_parallel_size > 1 and (
self.model_config is None
or self.model_config.is_moe
or not self.parallel_config.data_parallel_external_lb
)
def enable_trace_function_call_for_thread(self) -> None: def enable_trace_function_call_for_thread(self) -> None:
""" """
Set up function tracing for the current thread, Set up function tracing for the current thread,
...@@ -522,6 +545,8 @@ class VllmConfig: ...@@ -522,6 +545,8 @@ class VllmConfig:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
self.model_config.verify_dual_chunk_attention_config(self.load_config) self.model_config.verify_dual_chunk_attention_config(self.load_config)
self.parallel_config.is_moe_model = self.model_config.is_moe
self.cache_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.lora_config is not None: if self.lora_config is not None:
...@@ -827,9 +852,14 @@ class VllmConfig: ...@@ -827,9 +852,14 @@ class VllmConfig:
) )
# Do this after all the updates to compilation_config.mode # Do this after all the updates to compilation_config.mode
effective_dp_size = (
self.parallel_config.data_parallel_size
if self.model_config is None or self.model_config.is_moe
else 1
)
self.compilation_config.set_splitting_ops_for_v1( self.compilation_config.set_splitting_ops_for_v1(
all2all_backend=self.parallel_config.all2all_backend, all2all_backend=self.parallel_config.all2all_backend,
data_parallel_size=self.parallel_config.data_parallel_size, data_parallel_size=effective_dp_size,
) )
if self.compilation_config.pass_config.enable_sp: if self.compilation_config.pass_config.enable_sp:
...@@ -1297,13 +1327,8 @@ class VllmConfig: ...@@ -1297,13 +1327,8 @@ class VllmConfig:
if self.compilation_config.debug_dump_path is None: if self.compilation_config.debug_dump_path is None:
return None return None
tp_rank = self.parallel_config.rank tp_rank = self.parallel_config.rank
dp_rank = self.parallel_config.data_parallel_rank dp_rank = self.parallel_config.data_parallel_index
data_parallel_size = self.parallel_config.data_parallel_size append_path = f"rank_{tp_rank}_dp_{dp_rank}"
append_path = (
f"rank_{tp_rank}"
if data_parallel_size == 1
else f"rank_{tp_rank}_dp_{dp_rank}"
)
path = self.compilation_config.debug_dump_path / append_path path = self.compilation_config.debug_dump_path / append_path
return path return path
......
...@@ -915,6 +915,6 @@ def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int: ...@@ -915,6 +915,6 @@ def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
# This logic is now centralized # This logic is now centralized
return ( return (
envs.VLLM_MOONCAKE_BOOTSTRAP_PORT envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
+ vllm_config.parallel_config.data_parallel_rank + vllm_config.parallel_config.data_parallel_index
* vllm_config.parallel_config.tensor_parallel_size * vllm_config.parallel_config.tensor_parallel_size
) )
...@@ -471,7 +471,7 @@ class NixlConnectorScheduler: ...@@ -471,7 +471,7 @@ class NixlConnectorScheduler:
self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
self.side_channel_port = ( self.side_channel_port = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT envs.VLLM_NIXL_SIDE_CHANNEL_PORT
+ vllm_config.parallel_config.data_parallel_rank + vllm_config.parallel_config.data_parallel_index
) )
assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config is not None
if current_platform.device_type == "cpu": if current_platform.device_type == "cpu":
......
...@@ -1115,7 +1115,11 @@ _EP: GroupCoordinator | None = None ...@@ -1115,7 +1115,11 @@ _EP: GroupCoordinator | None = None
def get_ep_group() -> GroupCoordinator: def get_ep_group() -> GroupCoordinator:
assert _EP is not None, "expert parallel group is not initialized" assert _EP is not None, (
"expert parallel group is not initialized. "
"EP group is only created for MoE models with num_experts > 0. "
"This function should only be called for MoE models."
)
return _EP return _EP
...@@ -1400,20 +1404,23 @@ def initialize_model_parallel( ...@@ -1400,20 +1404,23 @@ def initialize_model_parallel(
global _EP global _EP
assert _EP is None, "expert parallel group is already initialized" assert _EP is None, "expert parallel group is already initialized"
group_ranks = ( # Don't create EP group for dense models.
all_ranks.transpose(1, 2) if config is None or config.model_config is None or config.model_config.is_moe:
.reshape( group_ranks = (
-1, all_ranks.transpose(1, 2)
data_parallel_size .reshape(
* prefill_context_model_parallel_size -1,
* tensor_model_parallel_size, data_parallel_size
* prefill_context_model_parallel_size
* tensor_model_parallel_size,
)
.unbind(0)
) )
.unbind(0) group_ranks = [x.tolist() for x in group_ranks]
) _EP = init_model_parallel_group(
group_ranks = [x.tolist() for x in group_ranks] group_ranks, get_world_group().local_rank, backend, group_name="ep"
_EP = init_model_parallel_group( )
group_ranks, get_world_group().local_rank, backend, group_name="ep" # If no EP group needed, _EP remains None
)
logger.info_once( logger.info_once(
"rank %s in world size %s is assigned as " "rank %s in world size %s is assigned as "
...@@ -1425,7 +1432,7 @@ def initialize_model_parallel( ...@@ -1425,7 +1432,7 @@ def initialize_model_parallel(
_PP.rank_in_group, _PP.rank_in_group,
_PCP.rank_in_group, _PCP.rank_in_group,
_TP.rank_in_group, _TP.rank_in_group,
_EP.rank_in_group, _EP.rank_in_group if _EP is not None else "N/A",
) )
......
...@@ -1575,6 +1575,7 @@ class EngineArgs: ...@@ -1575,6 +1575,7 @@ class EngineArgs:
data_parallel_rpc_port=data_parallel_rpc_port, data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=self.data_parallel_backend, data_parallel_backend=self.data_parallel_backend,
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
is_moe_model=model_config.is_moe,
enable_expert_parallel=self.enable_expert_parallel, enable_expert_parallel=self.enable_expert_parallel,
all2all_backend=self.all2all_backend, all2all_backend=self.all2all_backend,
enable_dbo=self.enable_dbo, enable_dbo=self.enable_dbo,
......
...@@ -102,6 +102,7 @@ class DPMetadata: ...@@ -102,6 +102,7 @@ class DPMetadata:
) -> "DPMetadata": ) -> "DPMetadata":
assert num_tokens_across_dp_cpu is not None assert num_tokens_across_dp_cpu is not None
assert parallel_config.data_parallel_size > 1 assert parallel_config.data_parallel_size > 1
assert parallel_config.is_moe_model is not False
dp_rank = parallel_config.data_parallel_rank dp_rank = parallel_config.data_parallel_rank
batchsize = num_tokens batchsize = num_tokens
......
...@@ -127,7 +127,7 @@ class Scheduler(SchedulerInterface): ...@@ -127,7 +127,7 @@ class Scheduler(SchedulerInterface):
self.kv_event_publisher = EventPublisherFactory.create( self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config, self.kv_events_config,
self.parallel_config.data_parallel_rank, self.parallel_config.data_parallel_index,
) )
self.ec_connector = None self.ec_connector = None
if self.vllm_config.ec_transfer_config is not None: if self.vllm_config.ec_transfer_config is not None:
......
...@@ -55,7 +55,9 @@ class DPCoordinator: ...@@ -55,7 +55,9 @@ class DPCoordinator:
request wave / running state changes. request wave / running state changes.
""" """
def __init__(self, parallel_config: ParallelConfig): def __init__(
self, parallel_config: ParallelConfig, enable_wave_coordination: bool = True
):
dp_size = parallel_config.data_parallel_size dp_size = parallel_config.data_parallel_size
assert dp_size > 1, "Coordinator only used for data parallel" assert dp_size > 1, "Coordinator only used for data parallel"
...@@ -83,6 +85,7 @@ class DPCoordinator: ...@@ -83,6 +85,7 @@ class DPCoordinator:
"front_publish_address": front_publish_address, "front_publish_address": front_publish_address,
"back_output_address": back_output_address, "back_output_address": back_output_address,
"back_publish_address": back_publish_address, "back_publish_address": back_publish_address,
"enable_wave_coordination": enable_wave_coordination,
}, },
daemon=True, daemon=True,
) )
...@@ -110,13 +113,19 @@ class EngineState: ...@@ -110,13 +113,19 @@ class EngineState:
class DPCoordinatorProc: class DPCoordinatorProc:
def __init__(self, engine_count: int, min_stats_update_interval_ms: int = 100): def __init__(
self,
engine_count: int,
min_stats_update_interval_ms: int = 100,
enable_wave_coordination: bool = True,
):
set_process_title("DPCoordinator") set_process_title("DPCoordinator")
self.ctx = zmq.Context() self.ctx = zmq.Context()
self.engines = [EngineState() for _ in range(engine_count)] self.engines = [EngineState() for _ in range(engine_count)]
self.stats_update_interval_ms = min_stats_update_interval_ms self.stats_update_interval_ms = min_stats_update_interval_ms
self.enable_wave_coordination = enable_wave_coordination
@staticmethod @staticmethod
def run_coordinator( def run_coordinator(
...@@ -125,10 +134,12 @@ class DPCoordinatorProc: ...@@ -125,10 +134,12 @@ class DPCoordinatorProc:
back_output_address: str, back_output_address: str,
back_publish_address: str, back_publish_address: str,
min_stats_update_interval_ms: int = 100, min_stats_update_interval_ms: int = 100,
enable_wave_coordination: bool = True,
): ):
coordinator = DPCoordinatorProc( coordinator = DPCoordinatorProc(
engine_count=engine_count, engine_count=engine_count,
min_stats_update_interval_ms=min_stats_update_interval_ms, min_stats_update_interval_ms=min_stats_update_interval_ms,
enable_wave_coordination=enable_wave_coordination,
) )
try: try:
coordinator.process_input_socket( coordinator.process_input_socket(
...@@ -265,22 +276,25 @@ class DPCoordinatorProc: ...@@ -265,22 +276,25 @@ class DPCoordinatorProc:
) )
continue # Skip normal engine notification processing continue # Skip normal engine notification processing
# We received a message on the front-end XPUB socket, # Wave coordination: handle new-request messages from front-end.
# from an API server sending a new request while the # Only process these when wave coordination is enabled
# engines are paused, so that we can wake the other if self.enable_wave_coordination:
# engines. # We received a message on the front-end XPUB socket,
engine_to_exclude, wave = decoded # from an API server sending a new request while the
if not engines_running: # engines are paused, so that we can wake the other
if wave < current_wave: # engines.
# If the wave number is stale, ensure the message engine_to_exclude, wave = decoded
# is handled by all the engines. if not engines_running:
engine_to_exclude = None if wave < current_wave:
# If the wave number is stale, ensure the message
engines_running = True # is handled by all the engines.
wave_state_changed = True engine_to_exclude = None
self._send_start_wave(
publish_back, current_wave, engine_to_exclude engines_running = True
) wave_state_changed = True
self._send_start_wave(
publish_back, current_wave, engine_to_exclude
)
if output_back in events: if output_back in events:
# We received a message from one of the engines. # We received a message from one of the engines.
...@@ -325,34 +339,39 @@ class DPCoordinatorProc: ...@@ -325,34 +339,39 @@ class DPCoordinatorProc:
stats[1] = scheduler_stats.num_running_reqs stats[1] = scheduler_stats.num_running_reqs
stats_changed = True stats_changed = True
if (wave := outputs.wave_complete) is not None: # Wave coordination: handle wave completion and start notifications
# 2. Notification from rank 0 engine that we've # Only process these when wave coordination is enabled
# moved into the global paused state if self.enable_wave_coordination:
# (engines_running==False). if (wave := outputs.wave_complete) is not None:
if current_wave <= wave: # 2. Notification from rank 0 engine that we've
new_wave = wave + 1 # moved into the global paused state
# (engines_running==False).
if current_wave <= wave:
new_wave = wave + 1
logger.debug(
"Moving DP wave from %d to %d.",
current_wave,
new_wave,
)
current_wave = new_wave
engines_running = False
wave_state_changed = True
elif (wave := outputs.start_wave) is not None and (
wave > current_wave
or (wave == current_wave and not engines_running)
):
# 3. The engine received request for a non-current wave
# so we must ensure that other engines progress to the
# next wave (race condition handling).
logger.debug( logger.debug(
"Moving DP wave from %d to %d.", current_wave, new_wave "Starting wave %d after notification of "
"stale wave request from engine.",
wave,
) )
current_wave = new_wave current_wave = wave
engines_running = False engines_running = True
wave_state_changed = True wave_state_changed = True
elif (wave := outputs.start_wave) is not None and ( self._send_start_wave(publish_back, wave, eng_index)
wave > current_wave
or (wave == current_wave and not engines_running)
):
# 3. The engine received request for a non-current wave
# so we must ensure that other engines progress to the
# next wave (race condition handling).
logger.debug(
"Starting wave %d after notification of "
"stale wave request from engine.",
wave,
)
current_wave = wave
engines_running = True
wave_state_changed = True
self._send_start_wave(publish_back, wave, eng_index)
if wave_state_changed: if wave_state_changed:
message = (None, current_wave, engines_running) message = (None, current_wave, engines_running)
......
...@@ -84,6 +84,7 @@ class EngineCore: ...@@ -84,6 +84,7 @@ class EngineCore:
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
executor_fail_callback: Callable | None = None, executor_fail_callback: Callable | None = None,
include_finished_set: bool = False,
): ):
# plugins need to be loaded at the engine/scheduler level too # plugins need to be loaded at the engine/scheduler level too
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
...@@ -91,7 +92,7 @@ class EngineCore: ...@@ -91,7 +92,7 @@ class EngineCore:
load_general_plugins() load_general_plugins()
self.vllm_config = vllm_config self.vllm_config = vllm_config
if vllm_config.parallel_config.data_parallel_rank == 0: if not vllm_config.parallel_config.data_parallel_rank_local:
logger.info( logger.info(
"Initializing a V1 LLM engine (v%s) with config: %s", "Initializing a V1 LLM engine (v%s) with config: %s",
VLLM_VERSION, VLLM_VERSION,
...@@ -138,7 +139,7 @@ class EngineCore: ...@@ -138,7 +139,7 @@ class EngineCore:
vllm_config=vllm_config, vllm_config=vllm_config,
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
structured_output_manager=self.structured_output_manager, structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size > 1, include_finished_set=include_finished_set,
log_stats=self.log_stats, log_stats=self.log_stats,
block_size=scheduler_block_size, block_size=scheduler_block_size,
) )
...@@ -605,6 +606,7 @@ class EngineCoreProc(EngineCore): ...@@ -605,6 +606,7 @@ class EngineCoreProc(EngineCore):
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
client_handshake_address: str | None = None, client_handshake_address: str | None = None,
*,
engine_index: int = 0, engine_index: int = 0,
): ):
self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
...@@ -636,17 +638,22 @@ class EngineCoreProc(EngineCore): ...@@ -636,17 +638,22 @@ class EngineCoreProc(EngineCore):
self.has_coordinator, self.has_coordinator,
self.frontend_stats_publish_address, self.frontend_stats_publish_address,
) )
# Only publish request queue stats to coordinator for "internal" internal_dp_balancing = (
# and "hybrid" LB modes .
self.publish_dp_lb_stats = (
self.has_coordinator self.has_coordinator
and not vllm_config.parallel_config.data_parallel_external_lb and not vllm_config.parallel_config.data_parallel_external_lb
) )
# Only publish request queue stats to coordinator for "internal"
# and "hybrid" LB modes.
self.publish_dp_lb_stats = internal_dp_balancing
self._init_data_parallel(vllm_config) self._init_data_parallel(vllm_config)
super().__init__( super().__init__(
vllm_config, executor_class, log_stats, executor_fail_callback vllm_config,
executor_class,
log_stats,
executor_fail_callback,
internal_dp_balancing,
) )
# Background Threads and Queues for IO. These enable us to # Background Threads and Queues for IO. These enable us to
...@@ -854,18 +861,29 @@ class EngineCoreProc(EngineCore): ...@@ -854,18 +861,29 @@ class EngineCoreProc(EngineCore):
engine_core: EngineCoreProc | None = None engine_core: EngineCoreProc | None = None
try: try:
parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config vllm_config: VllmConfig = kwargs["vllm_config"]
if parallel_config.data_parallel_size > 1 or dp_rank > 0: parallel_config: ParallelConfig = vllm_config.parallel_config
data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
if data_parallel:
parallel_config.data_parallel_rank_local = local_dp_rank
set_process_title("EngineCore", f"DP{dp_rank}") set_process_title("EngineCore", f"DP{dp_rank}")
decorate_logs() else:
set_process_title("EngineCore")
decorate_logs()
parallel_config.data_parallel_index = dp_rank
if data_parallel and vllm_config.model_config.is_moe:
# Set data parallel rank for this engine process. # Set data parallel rank for this engine process.
parallel_config.data_parallel_rank = dp_rank parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = local_dp_rank
engine_core = DPEngineCoreProc(*args, **kwargs) engine_core = DPEngineCoreProc(*args, **kwargs)
else: else:
set_process_title("EngineCore") # Non-MoE DP ranks are completely independent, so treat like DP=1.
decorate_logs() # Note that parallel_config.data_parallel_index will still reflect
engine_core = EngineCoreProc(*args, **kwargs) # the original DP rank.
parallel_config.data_parallel_size = 1
parallel_config.data_parallel_size_local = 1
parallel_config.data_parallel_rank = 0
engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
engine_core.run_busy_loop() engine_core.run_busy_loop()
...@@ -1195,6 +1213,10 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1195,6 +1213,10 @@ class DPEngineCoreProc(EngineCoreProc):
log_stats: bool, log_stats: bool,
client_handshake_address: str | None = None, client_handshake_address: str | None = None,
): ):
assert vllm_config.model_config.is_moe, (
"DPEngineCoreProc should only be used for MoE models"
)
# Counts forward-passes of the model so that we can synchronize # Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps. # finished with DP peers every N steps.
self.step_counter = 0 self.step_counter = 0
...@@ -1210,7 +1232,7 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1210,7 +1232,7 @@ class DPEngineCoreProc(EngineCoreProc):
executor_class, executor_class,
log_stats, log_stats,
client_handshake_address, client_handshake_address,
dp_rank, engine_index=dp_rank,
) )
def _init_data_parallel(self, vllm_config: VllmConfig): def _init_data_parallel(self, vllm_config: VllmConfig):
...@@ -1391,7 +1413,7 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1391,7 +1413,7 @@ class DPEngineCoreProc(EngineCoreProc):
) )
class DPEngineCoreActor(DPEngineCoreProc): class EngineCoreActorMixin:
""" """
Ray actor for running EngineCore in a data parallel context Ray actor for running EngineCore in a data parallel context
""" """
...@@ -1399,15 +1421,12 @@ class DPEngineCoreActor(DPEngineCoreProc): ...@@ -1399,15 +1421,12 @@ class DPEngineCoreActor(DPEngineCoreProc):
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
local_client: bool,
addresses: EngineZmqAddresses, addresses: EngineZmqAddresses,
executor_class: type[Executor],
log_stats: bool,
dp_rank: int = 0, dp_rank: int = 0,
local_dp_rank: int = 0, local_dp_rank: int = 0,
): ):
self.addresses = addresses self.addresses = addresses
vllm_config.parallel_config.data_parallel_rank = dp_rank vllm_config.parallel_config.data_parallel_index = dp_rank
vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank
# Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
...@@ -1429,8 +1448,6 @@ class DPEngineCoreActor(DPEngineCoreProc): ...@@ -1429,8 +1448,6 @@ class DPEngineCoreActor(DPEngineCoreProc):
# of ray. # of ray.
self._set_visible_devices(vllm_config, local_dp_rank) self._set_visible_devices(vllm_config, local_dp_rank)
super().__init__(vllm_config, local_client, "", executor_class, log_stats)
def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int): def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -1491,7 +1508,7 @@ class DPEngineCoreActor(DPEngineCoreProc): ...@@ -1491,7 +1508,7 @@ class DPEngineCoreActor(DPEngineCoreProc):
Run the engine core busy loop. Run the engine core busy loop.
""" """
try: try:
self.run_busy_loop() self.run_busy_loop() # type: ignore[attr-defined]
except SystemExit: except SystemExit:
logger.debug("EngineCore exiting.") logger.debug("EngineCore exiting.")
raise raise
...@@ -1499,4 +1516,58 @@ class DPEngineCoreActor(DPEngineCoreProc): ...@@ -1499,4 +1516,58 @@ class DPEngineCoreActor(DPEngineCoreProc):
logger.exception("EngineCore encountered a fatal error.") logger.exception("EngineCore encountered a fatal error.")
raise raise
finally: finally:
self.shutdown() self.shutdown() # type: ignore[attr-defined]
class DPMoEEngineCoreActor(EngineCoreActorMixin, DPEngineCoreProc):
"""Used for MoE model data parallel cases."""
def __init__(
self,
vllm_config: VllmConfig,
local_client: bool,
addresses: EngineZmqAddresses,
executor_class: type[Executor],
log_stats: bool,
dp_rank: int = 0,
local_dp_rank: int = 0,
):
vllm_config.parallel_config.data_parallel_rank = dp_rank
EngineCoreActorMixin.__init__(
self, vllm_config, addresses, dp_rank, local_dp_rank
)
DPEngineCoreProc.__init__(
self, vllm_config, local_client, "", executor_class, log_stats
)
class EngineCoreActor(EngineCoreActorMixin, EngineCoreProc):
"""Used for non-MoE and/or non-DP cases."""
def __init__(
self,
vllm_config: VllmConfig,
local_client: bool,
addresses: EngineZmqAddresses,
executor_class: type[Executor],
log_stats: bool,
dp_rank: int = 0,
local_dp_rank: int = 0,
):
vllm_config.parallel_config.data_parallel_size = 1
vllm_config.parallel_config.data_parallel_size_local = 1
vllm_config.parallel_config.data_parallel_rank = 0
EngineCoreActorMixin.__init__(
self, vllm_config, addresses, dp_rank, local_dp_rank
)
EngineCoreProc.__init__(
self,
vllm_config,
local_client,
"",
executor_class,
log_stats,
engine_index=dp_rank,
)
...@@ -502,7 +502,7 @@ class MPClient(EngineCoreClient): ...@@ -502,7 +502,7 @@ class MPClient(EngineCoreClient):
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank dp_rank = parallel_config.data_parallel_index
dp_local_size = parallel_config.data_parallel_size_local dp_local_size = parallel_config.data_parallel_size_local
offline_mode = parallel_config.data_parallel_rank_local is not None offline_mode = parallel_config.data_parallel_rank_local is not None
# Client manages local+remote EngineCores in pure internal LB case. # Client manages local+remote EngineCores in pure internal LB case.
......
...@@ -65,8 +65,9 @@ class LLMEngine: ...@@ -65,8 +65,9 @@ class LLMEngine:
self.log_stats = log_stats self.log_stats = log_stats
executor_backend = self.vllm_config.parallel_config.distributed_executor_backend
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
executor_backend = parallel_config.distributed_executor_backend
self.external_launcher_dp = ( self.external_launcher_dp = (
parallel_config.data_parallel_size > 1 parallel_config.data_parallel_size > 1
and executor_backend == "external_launcher" and executor_backend == "external_launcher"
......
...@@ -75,7 +75,6 @@ class EngineHandshakeMetadata: ...@@ -75,7 +75,6 @@ class EngineHandshakeMetadata:
addresses: EngineZmqAddresses addresses: EngineZmqAddresses
parallel_config: dict[str, int | str | list[int]] parallel_config: dict[str, int | str | list[int]]
parallel_config_hash: str | None = None
class CoreEngineProcManager: class CoreEngineProcManager:
...@@ -249,12 +248,19 @@ class CoreEngineActorManager: ...@@ -249,12 +248,19 @@ class CoreEngineActorManager:
from ray.runtime_env import RuntimeEnv from ray.runtime_env import RuntimeEnv
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from vllm.v1.engine.core import DPEngineCoreActor from vllm.v1.engine.core import DPMoEEngineCoreActor, EngineCoreActor
dp_size = vllm_config.parallel_config.data_parallel_size
actor_class = (
DPMoEEngineCoreActor
if dp_size > 1 and vllm_config.model_config.is_moe
else EngineCoreActor
)
self.local_engine_actors: list[ray.ActorHandle] = [] self.local_engine_actors: list[ray.ActorHandle] = []
self.remote_engine_actors: list[ray.ActorHandle] = [] self.remote_engine_actors: list[ray.ActorHandle] = []
env_vars_list = get_env_vars_to_copy(destination="DPEngineCoreActor") env_vars_list = get_env_vars_to_copy(destination=actor_class.__name__)
self.env_vars_dict = { self.env_vars_dict = {
name: os.environ[name] for name in env_vars_list if name in os.environ name: os.environ[name] for name in env_vars_list if name in os.environ
} }
...@@ -263,7 +269,6 @@ class CoreEngineActorManager: ...@@ -263,7 +269,6 @@ class CoreEngineActorManager:
self.addresses = addresses self.addresses = addresses
self.executor_class = executor_class self.executor_class = executor_class
self.log_stats = log_stats self.log_stats = log_stats
dp_size = vllm_config.parallel_config.data_parallel_size
local_engine_count = vllm_config.parallel_config.data_parallel_size_local local_engine_count = vllm_config.parallel_config.data_parallel_size_local
world_size = vllm_config.parallel_config.world_size world_size = vllm_config.parallel_config.world_size
...@@ -314,7 +319,7 @@ class CoreEngineActorManager: ...@@ -314,7 +319,7 @@ class CoreEngineActorManager:
runtime_env = RuntimeEnv(env_vars=actor_env_vars) runtime_env = RuntimeEnv(env_vars=actor_env_vars)
actor = ( actor = (
ray.remote(DPEngineCoreActor) ray.remote(actor_class)
.options( .options(
scheduling_strategy=PlacementGroupSchedulingStrategy( scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group=pg,
...@@ -624,7 +629,13 @@ class CoreEngineActorManager: ...@@ -624,7 +629,13 @@ class CoreEngineActorManager:
from ray.runtime_env import RuntimeEnv from ray.runtime_env import RuntimeEnv
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from vllm.v1.engine.core import DPEngineCoreActor from vllm.v1.engine.core import DPMoEEngineCoreActor, EngineCoreActor
actor_class = (
DPMoEEngineCoreActor
if cur_vllm_config.model_config.is_moe
else EngineCoreActor
)
cur_data_parallel_size = len(self.local_engine_actors) + len( cur_data_parallel_size = len(self.local_engine_actors) + len(
self.remote_engine_actors self.remote_engine_actors
...@@ -667,7 +678,7 @@ class CoreEngineActorManager: ...@@ -667,7 +678,7 @@ class CoreEngineActorManager:
) )
actor = ( actor = (
ray.remote(DPEngineCoreActor) ray.remote(actor_class)
.options( .options(
scheduling_strategy=PlacementGroupSchedulingStrategy( scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group=pg,
...@@ -804,12 +815,19 @@ def launch_core_engines( ...@@ -804,12 +815,19 @@ def launch_core_engines(
], ],
) )
# Run the DP Coordinator process with rank 0 when in # Run the DP Coordinator process with rank 0 when in online DP mode.
# online DP mode. # The coordinator is needed for:
run_coordinator = dp_size > 1 and not offline_mode and dp_rank == 0 # 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
# 2. MoE models: wave coordination in addition to stats
run_coordinator = (
vllm_config.needs_dp_coordinator and not offline_mode and dp_rank == 0
)
if run_coordinator: if run_coordinator:
coordinator = DPCoordinator(parallel_config) coordinator = DPCoordinator(
parallel_config,
enable_wave_coordination=vllm_config.model_config.is_moe,
)
addresses.coordinator_input, addresses.coordinator_output = ( addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses() coordinator.get_engine_socket_addresses()
...@@ -905,6 +923,7 @@ def launch_core_engines( ...@@ -905,6 +923,7 @@ def launch_core_engines(
addresses, addresses,
engines_to_handshake, engines_to_handshake,
parallel_config, parallel_config,
dp_size > 1 and vllm_config.model_config.is_moe,
vllm_config.cache_config, vllm_config.cache_config,
local_engine_manager, local_engine_manager,
coordinator.proc if coordinator else None, coordinator.proc if coordinator else None,
...@@ -916,6 +935,7 @@ def wait_for_engine_startup( ...@@ -916,6 +935,7 @@ def wait_for_engine_startup(
addresses: EngineZmqAddresses, addresses: EngineZmqAddresses,
core_engines: list[CoreEngine], core_engines: list[CoreEngine],
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
coordinated_dp: bool,
cache_config: CacheConfig, cache_config: CacheConfig,
proc_manager: CoreEngineProcManager | None, proc_manager: CoreEngineProcManager | None,
coord_process: Process | None, coord_process: Process | None,
...@@ -997,8 +1017,7 @@ def wait_for_engine_startup( ...@@ -997,8 +1017,7 @@ def wait_for_engine_startup(
) )
if status == "HELLO" and engine.state == CoreEngineState.NEW: if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info and config hash. # Send init message with DP config info.
# The config hash ensures all DP workers have compatible configs.
init_message = msgspec.msgpack.encode( init_message = msgspec.msgpack.encode(
EngineHandshakeMetadata( EngineHandshakeMetadata(
addresses=addresses, addresses=addresses,
...@@ -1010,10 +1029,9 @@ def wait_for_engine_startup( ...@@ -1010,10 +1029,9 @@ def wait_for_engine_startup(
"_data_parallel_master_port_list", "_data_parallel_master_port_list",
"data_parallel_size", "data_parallel_size",
) )
}, }
parallel_config_hash=parallel_config.compute_hash() if coordinated_dp
if parallel_config.data_parallel_size > 1 else {},
else None,
) )
) )
handshake_socket.send_multipart((eng_identity, init_message), copy=False) handshake_socket.send_multipart((eng_identity, init_message), copy=False)
...@@ -1034,8 +1052,8 @@ def wait_for_engine_startup( ...@@ -1034,8 +1052,8 @@ def wait_for_engine_startup(
if addresses.frontend_stats_publish_address is None: if addresses.frontend_stats_publish_address is None:
addresses.frontend_stats_publish_address = msg.get("dp_stats_address") addresses.frontend_stats_publish_address = msg.get("dp_stats_address")
# Validate config hash consistency across DP workers # Validate config hash consistency across DP workers for MoE models.
if parallel_config.data_parallel_size > 1: if coordinated_dp:
worker_config_hash = msg.get("parallel_config_hash") worker_config_hash = msg.get("parallel_config_hash")
expected_hash = parallel_config.compute_hash() expected_hash = parallel_config.compute_hash()
if worker_config_hash != expected_hash: if worker_config_hash != expected_hash:
......
...@@ -98,9 +98,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -98,9 +98,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.max_num_reqs = self.scheduler_config.max_num_seqs self.max_num_reqs = self.scheduler_config.max_num_seqs
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size() self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
self.dp_size = self.parallel_config.data_parallel_size
self.dp_rank = self.parallel_config.data_parallel_rank
self.use_async_scheduling = self.scheduler_config.async_scheduling self.use_async_scheduling = self.scheduler_config.async_scheduling
self.output_copy_stream = torch.cuda.Stream(self.device) self.output_copy_stream = torch.cuda.Stream(self.device)
self.output_copy_event = torch.cuda.Event() self.output_copy_event = torch.cuda.Event()
...@@ -268,7 +265,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -268,7 +265,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if not skip_attn: if not skip_attn:
self.prepare_dummy_attn_metadata(input_batch) self.prepare_dummy_attn_metadata(input_batch)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) dp_size = self.parallel_config.data_parallel_size
num_tokens_across_dp = make_num_tokens_across_dp(dp_size, num_tokens)
num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32) num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
with ( with (
self.maybe_dummy_run_with_lora( self.maybe_dummy_run_with_lora(
...@@ -312,7 +310,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -312,7 +310,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self._dummy_sampler_run(sample_hidden_states) self._dummy_sampler_run(sample_hidden_states)
if self.do_spec_decode: if self.do_spec_decode:
num_tokens_across_dp = make_num_tokens_across_dp( num_tokens_across_dp = make_num_tokens_across_dp(
self.dp_size, self.max_num_tokens self.parallel_config.data_parallel_size, self.max_num_tokens
) )
self.speculator.run_model( self.speculator.run_model(
self.max_num_tokens, self.max_num_tokens,
...@@ -807,7 +805,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -807,7 +805,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]: ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.dp_size == 1: dp_size = self.parallel_config.data_parallel_size
if dp_size == 1:
# No DP. Only consider CUDA graphs. # No DP. Only consider CUDA graphs.
if total_num_scheduled_tokens == 0: if total_num_scheduled_tokens == 0:
# Special case: no tokens to run. # Special case: no tokens to run.
...@@ -835,11 +834,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -835,11 +834,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_size_before_dp = -1 cudagraph_size_before_dp = -1
assert cudagraph_size_before_dp is not None assert cudagraph_size_before_dp is not None
dp_rank = self.parallel_config.data_parallel_rank
num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp( num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp(
total_num_scheduled_tokens, total_num_scheduled_tokens,
cudagraph_size_before_dp, cudagraph_size_before_dp,
self.dp_size, dp_size,
self.dp_rank, dp_rank,
) )
if all(cudagraph_size_across_dp >= 0): if all(cudagraph_size_across_dp >= 0):
# If all ranks can use CUDA graph, pad to the maximum number of tokens # If all ranks can use CUDA graph, pad to the maximum number of tokens
...@@ -850,7 +850,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -850,7 +850,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# If any of the ranks cannot use CUDA graph, use eager mode for all ranks. # If any of the ranks cannot use CUDA graph, use eager mode for all ranks.
# No padding is needed except for ranks that have no tokens to run. # No padding is needed except for ranks that have no tokens to run.
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1) num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
num_tokens_after_padding = num_tokens_across_dp[self.dp_rank] num_tokens_after_padding = num_tokens_across_dp[dp_rank]
cudagraph_mode = CUDAGraphMode.NONE cudagraph_mode = CUDAGraphMode.NONE
return cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp return cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp
......
...@@ -179,22 +179,20 @@ class Worker(WorkerBase): ...@@ -179,22 +179,20 @@ class Worker(WorkerBase):
self.cache_config.num_cpu_blocks = num_cpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks
def init_device(self): def init_device(self):
device = self.device_config.device if self.device_config.device_type == "cuda":
if isinstance(device, torch.device) and device.type == "cuda":
# This env var set by Ray causes exceptions with graph building. # This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
parallel_config = self.parallel_config
if ( if (
self.parallel_config.data_parallel_size > 1 parallel_config.distributed_executor_backend
and self.parallel_config.data_parallel_size_local > 0 not in ("ray", "external_launcher")
and self.parallel_config.distributed_executor_backend and parallel_config.data_parallel_backend != "ray"
not in ["ray", "external_launcher"] and parallel_config.nnodes_within_dp == 1
and self.vllm_config.parallel_config.data_parallel_backend != "ray"
and self.vllm_config.parallel_config.nnodes_within_dp == 1
): ):
# Use local DP rank if available, otherwise use global DP rank. # Use local DP rank if available, otherwise use global DP rank.
dp_local_rank = self.parallel_config.data_parallel_rank_local dp_local_rank = self.parallel_config.data_parallel_rank_local
if dp_local_rank is None: if dp_local_rank is None:
dp_local_rank = self.parallel_config.data_parallel_rank dp_local_rank = self.parallel_config.data_parallel_index
tp_pp_world_size = ( tp_pp_world_size = (
self.parallel_config.pipeline_parallel_size self.parallel_config.pipeline_parallel_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment