Unverified Commit dea26833 authored by Itay Alroy's avatar Itay Alroy Committed by GitHub
Browse files

[1/N] Elastic EP Milestone 2 (#34861)


Signed-off-by: default avatarYongji Wu <wuyongji317@gmail.com>
Signed-off-by: default avatarItay Alroy <ialroy@nvidia.com>
Signed-off-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: default avatarRon Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: default avatarYongji Wu <wuyongji317@gmail.com>
Co-authored-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: default avatarRon Tourgeman <rtourgeman@nvidia.com>
parent 90805ff4
...@@ -21,3 +21,18 @@ steps: ...@@ -21,3 +21,18 @@ steps:
commands: commands:
- pytest -v -s distributed/test_eplb_execute.py - pytest -v -s distributed/test_eplb_execute.py
- pytest -v -s distributed/test_eplb_spec_decode.py - pytest -v -s distributed/test_eplb_spec_decode.py
- label: Elastic EP Scaling Test
timeout_in_minutes: 20
device: b200
optional: true
working_dir: "/vllm-workspace/tests"
num_devices: 4
source_file_dependencies:
- vllm/distributed/
- vllm/engine/
- vllm/executor/
- vllm/compilation/
- tests/distributed/
commands:
- pytest -v -s distributed/test_elastic_ep.py
...@@ -316,7 +316,6 @@ def async_tp_pass_on_test_model( ...@@ -316,7 +316,6 @@ def async_tp_pass_on_test_model(
# initialize distributed # initialize distributed
init_distributed_environment() init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass # configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig() vllm_config = VllmConfig()
...@@ -334,11 +333,10 @@ def async_tp_pass_on_test_model( ...@@ -334,11 +333,10 @@ def async_tp_pass_on_test_model(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42 model=model_name, trust_remote_code=True, dtype=dtype, seed=42
) )
async_tp_pass = AsyncTPPass(vllm_config)
# Set the global vllm_config for TestBackend which calls
# get_current_vllm_config()
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
initialize_model_parallel(tensor_model_parallel_size=world_size)
async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass) backend = TestBackend(async_tp_pass)
assert ( assert (
......
...@@ -278,7 +278,6 @@ def all_reduce_fusion_pass_on_test_model( ...@@ -278,7 +278,6 @@ def all_reduce_fusion_pass_on_test_model(
) )
init_distributed_environment() init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
custom_ops = [] custom_ops = []
if enable_rms_norm_custom_op: if enable_rms_norm_custom_op:
...@@ -304,6 +303,7 @@ def all_reduce_fusion_pass_on_test_model( ...@@ -304,6 +303,7 @@ def all_reduce_fusion_pass_on_test_model(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42 model=model_name, trust_remote_code=True, dtype=dtype, seed=42
) )
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
initialize_model_parallel(tensor_model_parallel_size=world_size)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config)
......
...@@ -242,7 +242,6 @@ def sequence_parallelism_pass_on_test_model( ...@@ -242,7 +242,6 @@ def sequence_parallelism_pass_on_test_model(
# initialize distributed # initialize distributed
init_distributed_environment() init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass # configure vllm config for SequenceParallelismPass
custom_ops_list = custom_ops.split(",") if custom_ops else [] custom_ops_list = custom_ops.split(",") if custom_ops else []
...@@ -272,6 +271,7 @@ def sequence_parallelism_pass_on_test_model( ...@@ -272,6 +271,7 @@ def sequence_parallelism_pass_on_test_model(
) )
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
initialize_model_parallel(tensor_model_parallel_size=world_size)
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config) sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
......
...@@ -176,7 +176,11 @@ def init_test_http_connection(): ...@@ -176,7 +176,11 @@ def init_test_http_connection():
@pytest.fixture @pytest.fixture
def dist_init(): def dist_init():
from tests.utils import ensure_current_vllm_config
temp_file = tempfile.mkstemp()[1] temp_file = tempfile.mkstemp()[1]
with ensure_current_vllm_config():
init_distributed_environment( init_distributed_environment(
world_size=1, world_size=1,
rank=0, rank=0,
......
...@@ -7,6 +7,7 @@ import random ...@@ -7,6 +7,7 @@ import random
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
init_distributed_environment, init_distributed_environment,
) )
...@@ -42,6 +43,10 @@ def set_env_vars_and_device(env: dict[str, str]) -> None: ...@@ -42,6 +43,10 @@ def set_env_vars_and_device(env: dict[str, str]) -> None:
local_rank = os.environ["LOCAL_RANK"] local_rank = os.environ["LOCAL_RANK"]
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
# Create a minimal vllm config for init_distributed_environment
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
init_distributed_environment() init_distributed_environment()
# Ensure each worker process has the same random seed # Ensure each worker process has the same random seed
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import subprocess
import time
import pytest
import requests
from ..evals.gsm8k.gsm8k_eval import evaluate_gsm8k
from ..utils import RemoteOpenAIServer, multi_gpu_test
@pytest.fixture(autouse=True)
def cleanup_ray_between_tests():
"""Force-stop any lingering Ray processes between tests."""
subprocess.run(["ray", "stop", "--force"], timeout=30, capture_output=True)
time.sleep(5)
yield
MODEL_NAME = "deepseek-ai/DeepSeek-V2-Lite-Chat"
NUM_GSM8K_QUESTIONS = 256
EXPECTED_ACCURACY = 0.58
ACCURACY_TOL = 0.08
MAX_NUM_SEQS = 32
def _send_scale_command(server: RemoteOpenAIServer, new_dp_size: int) -> bool:
url = server.url_for("scale_elastic_ep")
payload = {"new_data_parallel_size": new_dp_size}
headers = {"Content-Type": "application/json"}
try:
response = requests.post(url, json=payload, headers=headers, timeout=300)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
def _run_gsm8k_eval(server: RemoteOpenAIServer, stage: str) -> float:
assert server.port is not None
result = evaluate_gsm8k(
num_questions=NUM_GSM8K_QUESTIONS,
host=f"http://{server.host}",
port=server.port,
)
accuracy = result["accuracy"]
print(
f"[{stage}] GSM8K accuracy: {accuracy:.3f} "
f"({result['num_questions']} questions)"
)
assert accuracy >= EXPECTED_ACCURACY, (
f"[{stage}] GSM8K accuracy {accuracy:.3f} is below "
f"expected threshold {EXPECTED_ACCURACY}"
)
return accuracy
@multi_gpu_test(num_gpus=4)
def test_elastic_ep_scaling():
vllm_serve_args = [
"--trust-remote-code",
"--tensor-parallel-size",
"1",
"--gpu-memory-utilization",
"0.8",
"--max-model-len",
"4096",
"--max-num-seqs",
str(MAX_NUM_SEQS),
"--enable-expert-parallel",
"--all2all-backend",
"allgather_reducescatter",
"--enable-elastic-ep",
"--enable-eplb",
"--eplb-config.num_redundant_experts",
"0",
"--data-parallel-backend",
"ray",
"--data-parallel-size",
"2",
"--api-server-count",
"1",
]
leader_address = os.environ.get("LEADER_ADDRESS")
if leader_address:
vllm_serve_args.extend(["--data-parallel-address", leader_address])
with RemoteOpenAIServer(
MODEL_NAME, vllm_serve_args, env_dict={}, max_wait_seconds=1200
) as server:
initial_accuracy = _run_gsm8k_eval(server, "Initial (2 GPUs)")
assert _send_scale_command(server, 4)
time.sleep(10)
scale_up_accuracy = _run_gsm8k_eval(server, "After scale up (4 GPUs)")
assert scale_up_accuracy >= initial_accuracy - ACCURACY_TOL, (
f"Scale up accuracy {scale_up_accuracy:.3f} dropped more than "
f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}"
)
assert _send_scale_command(server, 2)
time.sleep(5)
scale_down_accuracy = _run_gsm8k_eval(server, "After scale down (2 GPUs)")
assert scale_down_accuracy >= initial_accuracy - ACCURACY_TOL, (
f"Scale down accuracy {scale_down_accuracy:.3f} dropped more than "
f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}"
)
print("\nAccuracy Summary:")
print(f" Initial: {initial_accuracy:.3f}")
print(
f" Scale up: {scale_up_accuracy:.3f} "
f"(diff: {scale_up_accuracy - initial_accuracy:+.3f})"
)
print(
f" Scale down: {scale_down_accuracy:.3f} "
f"(diff: {scale_down_accuracy - initial_accuracy:+.3f})"
)
print(f" Tolerance: {ACCURACY_TOL:.3f}")
@multi_gpu_test(num_gpus=4)
def test_elastic_ep_scaling_uneven():
"""Test scale up with uneven worker distribution.
This tests the case where num_new_workers % old_dp_size != 0,
specifically 2 -> 3 where remainder = 1 % 2 = 1.
This exercises the remainder handling in sender-receiver pairing.
"""
vllm_serve_args = [
"--trust-remote-code",
"--tensor-parallel-size",
"1",
"--gpu-memory-utilization",
"0.8",
"--max-model-len",
"4096",
"--max-num-seqs",
str(MAX_NUM_SEQS),
"--enable-expert-parallel",
"--all2all-backend",
"allgather_reducescatter",
"--enable-elastic-ep",
"--enable-eplb",
"--eplb-config.num_redundant_experts",
"0",
"--data-parallel-backend",
"ray",
"--data-parallel-size",
"2",
"--api-server-count",
"1",
]
leader_address = os.environ.get("LEADER_ADDRESS")
if leader_address:
vllm_serve_args.extend(["--data-parallel-address", leader_address])
with RemoteOpenAIServer(
MODEL_NAME, vllm_serve_args, env_dict={}, max_wait_seconds=1200
) as server:
initial_accuracy = _run_gsm8k_eval(server, "Initial (2 GPUs)")
# Scale 2 -> 3: This has remainder = 1 % 2 = 1
# Tests uneven sender-receiver pairing
assert _send_scale_command(server, 3)
time.sleep(10)
scale_up_accuracy = _run_gsm8k_eval(server, "After scale up (3 GPUs)")
assert scale_up_accuracy >= initial_accuracy - ACCURACY_TOL, (
f"Scale up accuracy {scale_up_accuracy:.3f} dropped more than "
f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}"
)
# Scale back down to 2
assert _send_scale_command(server, 2)
time.sleep(5)
scale_down_accuracy = _run_gsm8k_eval(server, "After scale down (2 GPUs)")
assert scale_down_accuracy >= initial_accuracy - ACCURACY_TOL, (
f"Scale down accuracy {scale_down_accuracy:.3f} dropped more than "
f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}"
)
print("\nAccuracy Summary (Uneven Scaling):")
print(f" Initial: {initial_accuracy:.3f}")
print(
f" Scale up: {scale_up_accuracy:.3f} "
f"(diff: {scale_up_accuracy - initial_accuracy:+.3f})"
)
print(
f" Scale down: {scale_down_accuracy:.3f} "
f"(diff: {scale_down_accuracy - initial_accuracy:+.3f})"
)
print(f" Tolerance: {ACCURACY_TOL:.3f}")
...@@ -8,6 +8,7 @@ import pytest ...@@ -8,6 +8,7 @@ import pytest
import torch import torch
import torch.distributed import torch.distributed
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.rebalance_execute import ( from vllm.distributed.eplb.rebalance_execute import (
move_from_buffer, move_from_buffer,
rearrange_expert_weights_inplace, rearrange_expert_weights_inplace,
...@@ -244,6 +245,11 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -244,6 +245,11 @@ def _test_async_transfer_layer_without_mtp_worker(
num_logical_experts: int, num_logical_experts: int,
) -> None: ) -> None:
set_env_vars_and_device(env) set_env_vars_and_device(env)
vllm_config = VllmConfig()
vllm_config.parallel_config.tensor_parallel_size = world_size
with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
) )
...@@ -336,6 +342,11 @@ def _test_rearrange_expert_weights_with_redundancy( ...@@ -336,6 +342,11 @@ def _test_rearrange_expert_weights_with_redundancy(
# Initialize model parallel (using tensor parallel as an entrypoint # Initialize model parallel (using tensor parallel as an entrypoint
# to expert parallel) # to expert parallel)
set_env_vars_and_device(env) set_env_vars_and_device(env)
vllm_config = VllmConfig()
vllm_config.parallel_config.tensor_parallel_size = world_size
with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
) )
...@@ -444,6 +455,11 @@ def test_rearrange_expert_weights_with_redundancy( ...@@ -444,6 +455,11 @@ def test_rearrange_expert_weights_with_redundancy(
def _test_rearrange_expert_weights_no_change(env, world_size) -> None: def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
set_env_vars_and_device(env) set_env_vars_and_device(env)
vllm_config = VllmConfig()
vllm_config.parallel_config.tensor_parallel_size = world_size
with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
) )
...@@ -538,6 +554,11 @@ def test_rearrange_expert_weights_no_change(world_size): ...@@ -538,6 +554,11 @@ def test_rearrange_expert_weights_no_change(world_size):
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None: def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
set_env_vars_and_device(env) set_env_vars_and_device(env)
vllm_config = VllmConfig()
vllm_config.parallel_config.tensor_parallel_size = world_size
with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
) )
......
...@@ -10,6 +10,7 @@ import torch.distributed as dist ...@@ -10,6 +10,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import vllm.envs as envs import vllm.envs as envs
from tests.utils import ensure_current_vllm_config
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
...@@ -51,6 +52,7 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int): ...@@ -51,6 +52,7 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
) )
init_distributed_environment() init_distributed_environment()
with ensure_current_vllm_config():
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
cuda_communicator = typing.cast( cuda_communicator = typing.cast(
......
...@@ -9,6 +9,7 @@ import pytest ...@@ -9,6 +9,7 @@ import pytest
import torch import torch
import torch.distributed import torch.distributed
from tests.utils import ensure_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
...@@ -112,6 +113,7 @@ def test_pynccl_multiple_allreduce(): ...@@ -112,6 +113,7 @@ def test_pynccl_multiple_allreduce():
@worker_fn_wrapper @worker_fn_wrapper
def multiple_allreduce_with_vllm_worker_fn(): def multiple_allreduce_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}") device = torch.device(f"cuda:{torch.distributed.get_rank()}")
with ensure_current_vllm_config():
ensure_model_parallel_initialized(2, 2) ensure_model_parallel_initialized(2, 2)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
with graph_capture(device=device): with graph_capture(device=device):
......
...@@ -6,7 +6,7 @@ import unittest ...@@ -6,7 +6,7 @@ import unittest
import pytest import pytest
import torch import torch
from tests.utils import multi_gpu_test from tests.utils import ensure_current_vllm_config, multi_gpu_test
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
init_distributed_environment, init_distributed_environment,
initialize_model_parallel, initialize_model_parallel,
...@@ -87,6 +87,7 @@ def mixer2_gated_norm_tensor_parallel( ...@@ -87,6 +87,7 @@ def mixer2_gated_norm_tensor_parallel(
# initialize distributed # initialize distributed
init_distributed_environment() init_distributed_environment()
with ensure_current_vllm_config():
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
# create random weights an inputs # create random weights an inputs
......
...@@ -45,12 +45,15 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): ...@@ -45,12 +45,15 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
@pytest.fixture @pytest.fixture
def dist_init(): def dist_init():
from tests.utils import ensure_current_vllm_config
temp_file = tempfile.mkstemp()[1] temp_file = tempfile.mkstemp()[1]
backend = "nccl" backend = "nccl"
if current_platform.is_cpu() or current_platform.is_tpu(): if current_platform.is_cpu() or current_platform.is_tpu():
backend = "gloo" backend = "gloo"
with ensure_current_vllm_config():
init_distributed_environment( init_distributed_environment(
world_size=1, world_size=1,
rank=0, rank=0,
......
...@@ -6,7 +6,7 @@ import random ...@@ -6,7 +6,7 @@ import random
import pytest import pytest
import torch import torch
from tests.utils import multi_gpu_test from tests.utils import ensure_current_vllm_config, multi_gpu_test
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import ( from vllm.distributed import (
init_distributed_environment, init_distributed_environment,
...@@ -631,6 +631,7 @@ def use_fused_moe_lora_kernel_tensor_parallel( ...@@ -631,6 +631,7 @@ def use_fused_moe_lora_kernel_tensor_parallel(
local_rank=local_rank, local_rank=local_rank,
distributed_init_method=init_method, distributed_init_method=init_method,
) )
with ensure_current_vllm_config():
initialize_model_parallel(world_size, 1) initialize_model_parallel(world_size, 1)
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
......
...@@ -13,6 +13,7 @@ from vllm.config import ( ...@@ -13,6 +13,7 @@ from vllm.config import (
ParallelConfig, ParallelConfig,
SchedulerConfig, SchedulerConfig,
VllmConfig, VllmConfig,
set_current_vllm_config,
) )
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
...@@ -77,6 +78,7 @@ def test_worker_apply_lora(qwen3_lora_files): ...@@ -77,6 +78,7 @@ def test_worker_apply_lora(qwen3_lora_files):
distributed_init_method=f"file://{tempfile.mkstemp()[1]}", distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
) )
with set_current_vllm_config(vllm_config):
worker.init_device() worker.init_device()
worker.load_model() worker.load_model()
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from tests.utils import multi_gpu_test from tests.utils import ensure_current_vllm_config, multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
init_distributed_environment, init_distributed_environment,
...@@ -117,6 +117,7 @@ def run_dp_sharded_vision_model_vs_direct( ...@@ -117,6 +117,7 @@ def run_dp_sharded_vision_model_vs_direct(
# initialize distributed # initialize distributed
init_distributed_environment() init_distributed_environment()
with ensure_current_vllm_config():
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create a test input tensor # Create a test input tensor
...@@ -302,6 +303,7 @@ def run_dp_sharded_mrope_vision_model_vs_direct( ...@@ -302,6 +303,7 @@ def run_dp_sharded_mrope_vision_model_vs_direct(
# initialize distributed # initialize distributed
init_distributed_environment() init_distributed_environment()
with ensure_current_vllm_config():
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create test data # Create test data
...@@ -377,6 +379,7 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker( ...@@ -377,6 +379,7 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker(
) )
init_distributed_environment() init_distributed_environment()
with ensure_current_vllm_config():
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create empty inputs # Create empty inputs
...@@ -425,6 +428,7 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker( ...@@ -425,6 +428,7 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker(
) )
init_distributed_environment() init_distributed_environment()
with ensure_current_vllm_config():
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create images with very different sizes # Create images with very different sizes
......
...@@ -895,6 +895,36 @@ def compare_all_settings( ...@@ -895,6 +895,36 @@ def compare_all_settings(
) )
@contextmanager
def ensure_current_vllm_config():
"""Ensures a vllm config is set for the duration of the context.
If a config is already set, this is a no-op. Otherwise, it creates a default
VllmConfig and sets it for the duration of the context.
Used for tests that call functions which require a vllm config but don't
need a specific config.
Example:
with ensure_current_vllm_config():
init_distributed_environment(...)
ensure_model_parallel_initialized(...)
"""
from vllm.config import (
VllmConfig,
get_current_vllm_config_or_none,
set_current_vllm_config,
)
if get_current_vllm_config_or_none() is not None:
# Config already set, just yield
yield
else:
# No config set, create a default one for the duration
with set_current_vllm_config(VllmConfig()):
yield
def init_test_distributed_environment( def init_test_distributed_environment(
tp_size: int, tp_size: int,
pp_size: int, pp_size: int,
...@@ -921,6 +951,7 @@ def init_test_distributed_environment( ...@@ -921,6 +951,7 @@ def init_test_distributed_environment(
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
local_rank=local_rank, local_rank=local_rank,
) )
ensure_model_parallel_initialized(tp_size, pp_size)
else: else:
# No config set, create a default one for the test # No config set, create a default one for the test
with set_current_vllm_config(VllmConfig()): with set_current_vllm_config(VllmConfig()):
......
...@@ -789,6 +789,9 @@ def test_hybrid_attention_mamba_tensor_shapes(): ...@@ -789,6 +789,9 @@ def test_hybrid_attention_mamba_tensor_shapes():
"MASTER_PORT": "12345", "MASTER_PORT": "12345",
} }
) )
from tests.utils import ensure_current_vllm_config
with ensure_current_vllm_config():
init_distributed_environment() init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=1) initialize_model_parallel(tensor_model_parallel_size=1)
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float16)
......
...@@ -10,6 +10,7 @@ from unittest.mock import patch ...@@ -10,6 +10,7 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.config import set_current_vllm_config
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.utils.mem_utils import MemorySnapshot from vllm.utils.mem_utils import MemorySnapshot
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
...@@ -95,7 +96,12 @@ def worker_process( ...@@ -95,7 +96,12 @@ def worker_process(
side_effect=make_operation_tracker("nccl_all_reduce", original_all_reduce), side_effect=make_operation_tracker("nccl_all_reduce", original_all_reduce),
) )
with init_patch, memory_patch, all_reduce_patch: with (
init_patch,
memory_patch,
all_reduce_patch,
set_current_vllm_config(vllm_config),
):
# Initialize device (this is where we test the order) # Initialize device (this is where we test the order)
worker.init_device() worker.init_device()
......
...@@ -319,3 +319,52 @@ class TorchCompileWithNoGuardsWrapper: ...@@ -319,3 +319,52 @@ class TorchCompileWithNoGuardsWrapper:
yield yield
finally: finally:
self.__class__.forward.__code__ = original self.__class__.forward.__code__ = original
def reset_compile_wrapper(model: torch.nn.Module) -> None:
"""
Clean up compiled model and captured CUDA graphs for elastic EP.
"""
if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr(
model, "model"
):
model = model.model
if not isinstance(model, TorchCompileWithNoGuardsWrapper):
return
# model.do_not_compile is set by the @support_torch_compile decorator
if hasattr(model, "do_not_compile") and model.do_not_compile:
return
from vllm.compilation.counter import compilation_counter
# reset the compilation counter
compilation_counter.num_models_seen = 0
compilation_counter.num_graphs_seen = 0
compilation_counter.num_piecewise_graphs_seen = 0
compilation_counter.num_piecewise_capturable_graphs_seen = 0
compilation_counter.num_backend_compilations = 0
compilation_counter.num_gpu_runner_capture_triggers = 0
compilation_counter.num_cudagraph_captured = 0
compilation_counter.num_inductor_compiles = 0
compilation_counter.num_eager_compiles = 0
compilation_counter.num_cache_entries_updated = 0
compilation_counter.num_compiled_artifacts_saved = 0
compilation_counter.stock_torch_compile_count = 0
# Clear the AOT compiled function so the model is forced to
# recompile on the next call. Without this, decorators.py
# __call__ uses the stale aot_compiled_fn whose torchinductor
# kernels have old parameters (expert_map size for example)
# baked in as compile-time constants.
if hasattr(model, "aot_compiled_fn"):
model.aot_compiled_fn = None
if hasattr(model, "was_aot_compile_fn_loaded_from_disk"):
model.was_aot_compile_fn_loaded_from_disk = False
# Reset the cache_dir so VllmBackend recomputes the hash
# (data_parallel_size changed, so the config hash differs).
compilation_config = model.vllm_config.compilation_config
compilation_config.cache_dir = ""
compilation_config.local_cache_dir = ""
model.__class__.forward.__code__ = model.original_code_object()
TorchCompileWithNoGuardsWrapper.__init__(model)
...@@ -165,6 +165,9 @@ class ParallelConfig: ...@@ -165,6 +165,9 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL.""" """Disable the custom all-reduce kernel and fall back to NCCL."""
enable_elastic_ep: bool = False
"""Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""
enable_dbo: bool = False enable_dbo: bool = False
"""Enable dual batch overlap for the model executor.""" """Enable dual batch overlap for the model executor."""
ubatch_size: int = 0 ubatch_size: int = 0
...@@ -244,6 +247,34 @@ class ParallelConfig: ...@@ -244,6 +247,34 @@ class ParallelConfig:
Set to be private as it's not intended to be configured by users. Set to be private as it's not intended to be configured by users.
""" """
_stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
It is a list of list[int], with each inner list contains a set of 3 ports
to be used for setting up the stateless CPU/device/TCPStore groups
in StatelessGroupCoordinator. The number of inner lists is equal to
the number of DP groups,
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
"""
_stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size,
"""
_stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless EPLB groups when enable_elastic_ep is True.
Same topology as EP but separate NCCL communicator to avoid deadlocks.
"""
_stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless world group when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_world_group_port_list) == 1,
"""
decode_context_parallel_size: int = 1 decode_context_parallel_size: int = 1
"""Number of decode context parallel groups, because the world size does """Number of decode context parallel groups, because the world size does
not change by dcp, it simply reuse the GPUs of TP group, and tp_size not change by dcp, it simply reuse the GPUs of TP group, and tp_size
...@@ -402,7 +433,67 @@ class ParallelConfig: ...@@ -402,7 +433,67 @@ class ParallelConfig:
return answer return answer
def stateless_init_dp_group(self) -> ProcessGroup: def allocate_elastic_ep_ports(self) -> None:
"""Allocate all ports for elastic EP (stateless groups + DP master).
Must be called AFTER ray.init() so that ports claimed by Ray's
idle worker pool are already in use and won't be returned by
get_open_ports_list().
"""
if not self.enable_elastic_ep:
return
if self._stateless_world_group_port_list:
return
num_world_groups = 1
dp_size = self.data_parallel_size
ep_size = self.data_parallel_size * self.world_size_across_dp
num_dp_groups = max(1, self.world_size_across_dp // dp_size)
num_ep_groups = max(1, self.world_size_across_dp // ep_size)
num_eplb_groups = num_ep_groups
total_stateless_ports = (
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
) * 3
num_dp_master_ports = 5
all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports)
self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:]
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
all_ports = all_ports[:-num_dp_master_ports]
self._stateless_world_group_port_list = [
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
]
start_idx = num_world_groups * 3
self._stateless_dp_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
]
start_idx += num_dp_groups * 3
self._stateless_ep_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
]
start_idx += num_ep_groups * 3
self._stateless_eplb_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
]
def get_next_stateless_world_group_port(self) -> list[int]:
return self._stateless_world_group_port_list.pop()
def get_next_stateless_dp_group_port(self) -> list[int]:
return self._stateless_dp_group_port_list.pop()
def get_next_stateless_ep_group_port(self) -> list[int]:
return self._stateless_ep_group_port_list.pop()
def get_next_stateless_eplb_group_port(self) -> list[int]:
return self._stateless_eplb_group_port_list.pop()
def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
# NOTE: In high-concurrency scenarios multiple processes # NOTE: In high-concurrency scenarios multiple processes
# can pick the same (currently free) port through a race # can pick the same (currently free) port through a race
# condition when calling `get_open_port()`. When the first # condition when calling `get_open_port()`. When the first
...@@ -426,7 +517,8 @@ class ParallelConfig: ...@@ -426,7 +517,8 @@ class ParallelConfig:
self.get_next_dp_init_port(), self.get_next_dp_init_port(),
self.data_parallel_rank, self.data_parallel_rank,
self.data_parallel_size, self.data_parallel_size,
backend=current_platform.dist_backend, backend="gloo",
return_store=return_store,
) )
except DistNetworkError as e: except DistNetworkError as e:
# We only want to retry when the root cause is EADDRINUSE. # We only want to retry when the root cause is EADDRINUSE.
...@@ -561,6 +653,21 @@ class ParallelConfig: ...@@ -561,6 +653,21 @@ class ParallelConfig:
logger.info("Using external launcher for distributed inference.") logger.info("Using external launcher for distributed inference.")
self.world_size *= self.data_parallel_size self.world_size *= self.data_parallel_size
if self.enable_elastic_ep:
if not self.enable_eplb:
raise ValueError("Elastic EP is only supported with enable_eplb=True.")
if self.pipeline_parallel_size > 1:
raise ValueError(
"Elastic EP is not supported with pipeline parallelism "
f"(pipeline_parallel_size={self.pipeline_parallel_size})."
)
if self.data_parallel_external_lb or self.data_parallel_hybrid_lb:
raise NotImplementedError(
"Elastic EP is not compatible with data_parallel_external_lb "
"or data_parallel_hybrid_lb. Elastic EP relies on a single API "
"server and core client to coordinate scale up/down."
)
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args. # Data parallel was specified in the engine args.
if self.distributed_executor_backend == "external_launcher": if self.distributed_executor_backend == "external_launcher":
...@@ -573,9 +680,12 @@ class ParallelConfig: ...@@ -573,9 +680,12 @@ class ParallelConfig:
"Set data_parallel_rank to %d automatically.", "Set data_parallel_rank to %d automatically.",
self.data_parallel_rank, self.data_parallel_rank,
) )
if not self.enable_elastic_ep:
if not self._data_parallel_master_port_list: if not self._data_parallel_master_port_list:
self._data_parallel_master_port_list = get_open_ports_list(5) self._data_parallel_master_port_list = get_open_ports_list(5)
self.data_parallel_master_port = self._data_parallel_master_port_list.pop() self.data_parallel_master_port = (
self._data_parallel_master_port_list.pop()
)
if not (0 <= self.data_parallel_rank < self.data_parallel_size): if not (0 <= self.data_parallel_rank < self.data_parallel_size):
raise ValueError( raise ValueError(
...@@ -602,7 +712,7 @@ class ParallelConfig: ...@@ -602,7 +712,7 @@ class ParallelConfig:
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.") logger.info("Disabling V1 multiprocessing for external launcher.")
if self.distributed_executor_backend is None and self.world_size > 1: if self.distributed_executor_backend is None and self.world_size_across_dp > 1:
# We use multiprocessing by default if world_size fits on the # We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group. # current node and we aren't in a ray placement group.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment