"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "40d0e7411dbeb276befd33c4485115ac3d4d7f2a"
Unverified Commit abdbb683 authored by Ilya Markov's avatar Ilya Markov Committed by GitHub
Browse files

[EPLB] Add alternative communication for EPLB weight exchange (#33176)


Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Markov Ilya <markovilya19@gmail.com>
Co-authored-by: Markov Ilya <markovilya19@gmail.com>
parent 0c637391
...@@ -13,8 +13,8 @@ steps: ...@@ -13,8 +13,8 @@ steps:
- pytest -v -s distributed/test_eplb_algo.py - pytest -v -s distributed/test_eplb_algo.py
- pytest -v -s distributed/test_eplb_utils.py - pytest -v -s distributed/test_eplb_utils.py
- label: EPLB Execution - label: EPLB Execution # 17min
timeout_in_minutes: 20 timeout_in_minutes: 27
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_devices: 4 num_devices: 4
source_file_dependencies: source_file_dependencies:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import os import os
import random import random
import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
...@@ -16,9 +18,20 @@ from vllm.utils.system_utils import update_environment_variables ...@@ -16,9 +18,20 @@ from vllm.utils.system_utils import update_environment_variables
mp.set_start_method("spawn", force=True) mp.set_start_method("spawn", force=True)
def _distributed_worker_wrapper(fn, env, world_size, args, rank, skip_queue):
try:
fn(env, world_size, *args)
except BaseException as exc:
if isinstance(exc, pytest.skip.Exception):
skip_queue.put((rank, str(exc)))
return
raise
def distributed_run(fn, world_size, *args): def distributed_run(fn, world_size, *args):
number_of_processes = world_size number_of_processes = world_size
processes: list[mp.Process] = [] processes: list[mp.Process] = []
skip_queue: mp.SimpleQueue = mp.SimpleQueue()
for i in range(number_of_processes): for i in range(number_of_processes):
env: dict[str, str] = {} env: dict[str, str] = {}
env["RANK"] = str(i) env["RANK"] = str(i)
...@@ -27,13 +40,32 @@ def distributed_run(fn, world_size, *args): ...@@ -27,13 +40,32 @@ def distributed_run(fn, world_size, *args):
env["LOCAL_WORLD_SIZE"] = str(number_of_processes) env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost" env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345" env["MASTER_PORT"] = "12345"
p = mp.Process(target=fn, args=(env, world_size, *args)) p = mp.Process(
target=_distributed_worker_wrapper,
args=(fn, env, world_size, args, i, skip_queue),
)
processes.append(p) processes.append(p)
p.start() p.start()
for p in processes: for p in processes:
p.join() p.join()
skipped: list[tuple[int, str]] = []
while not skip_queue.empty():
rank, reason = skip_queue.get()
skipped.append((rank, reason))
if len(skipped) == number_of_processes:
reason = skipped[0][1]
pytest.skip(reason)
if 0 < len(skipped) < number_of_processes:
skipped_ranks = sorted(rank for rank, _ in skipped)
raise AssertionError(
"Distributed test had partial skips; expected either all ranks "
f"to skip or none. Skipped ranks: {skipped_ranks}, "
f"total ranks: {number_of_processes}"
)
for p in processes: for p in processes:
assert p.exitcode == 0 assert p.exitcode == 0
...@@ -48,7 +80,12 @@ def set_env_vars_and_device(env: dict[str, str]) -> None: ...@@ -48,7 +80,12 @@ def set_env_vars_and_device(env: dict[str, str]) -> None:
vllm_config = VllmConfig() vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
init_distributed_environment() init_distributed_environment()
atexit.register(_destroy_process_group_if_initialized)
# Ensure each worker process has the same random seed # Ensure each worker process has the same random seed
random.seed(42) random.seed(42)
torch.manual_seed(42) torch.manual_seed(42)
def _destroy_process_group_if_initialized() -> None:
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
import torch.distributed import torch.distributed
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
from vllm.distributed.eplb.rebalance_execute import ( from vllm.distributed.eplb.rebalance_execute import (
move_from_buffer, move_from_buffer,
rearrange_expert_weights_inplace, rearrange_expert_weights_inplace,
...@@ -130,9 +131,10 @@ def verify_expert_weights_after_shuffle( ...@@ -130,9 +131,10 @@ def verify_expert_weights_after_shuffle(
hidden_sizes: list[int], hidden_sizes: list[int],
ep_rank: int, ep_rank: int,
num_local_experts: int, num_local_experts: int,
): ) -> bool:
"""Verify the weights after shuffling are correct.""" """Verify the weights after shuffling are correct."""
num_layers = len(expert_weights) num_layers = len(expert_weights)
ok = True
for layer in range(num_layers): for layer in range(num_layers):
for weight_idx, hidden_size in enumerate(hidden_sizes): for weight_idx, hidden_size in enumerate(hidden_sizes):
...@@ -155,29 +157,38 @@ def verify_expert_weights_after_shuffle( ...@@ -155,29 +157,38 @@ def verify_expert_weights_after_shuffle(
dtype=actual_weights.dtype, dtype=actual_weights.dtype,
) )
torch.testing.assert_close( if not torch.equal(actual_weights, expected_weights):
actual_weights, ok = False
expected_weights, actual_head = actual_weights[:8].detach().cpu().tolist()
msg=f"Layer {layer}, weight {weight_idx}," expected_head = expected_weights[:8].detach().cpu().tolist()
f"local expert {local_expert}: " print(
f"weights do not match. " "verify_expert_weights_after_shuffle failed: "
f"Expected logical expert {expected_logical_expert}", f"rank={ep_rank}, "
) f"layer={layer}, weight_idx={weight_idx}, "
f"local_expert={local_expert}, "
f"expected_logical_expert={expected_logical_expert}, "
f"actual_head={actual_head}, expected_head={expected_head}",
flush=True,
)
return ok
def verify_redundant_experts_have_same_weights( def verify_redundant_experts_have_same_weights(
expert_weights: list[list[torch.Tensor]], expert_weights: list[list[torch.Tensor]],
indices: torch.Tensor, indices: torch.Tensor,
hidden_sizes: list[int], hidden_sizes: list[int],
ep_rank: int,
world_size: int, world_size: int,
num_local_experts: int, num_local_experts: int,
): ) -> bool:
""" """
Verify that all replicas of the same logical expert have the same weights. Verify that all replicas of the same logical expert have the same weights.
""" """
num_layers = len(expert_weights) num_layers = len(expert_weights)
total_physical_experts = world_size * num_local_experts total_physical_experts = world_size * num_local_experts
ok = True
for layer in range(num_layers): for layer in range(num_layers):
# Collect weights for all physical experts for each weight matrix # Collect weights for all physical experts for each weight matrix
all_weights: list[torch.Tensor] = [] all_weights: list[torch.Tensor] = []
...@@ -227,14 +238,54 @@ def verify_redundant_experts_have_same_weights( ...@@ -227,14 +238,54 @@ def verify_redundant_experts_have_same_weights(
# Verify that current physical expert's weights match the # Verify that current physical expert's weights match the
# previously saved logical expert weights # previously saved logical expert weights
for weight_idx in range(len(hidden_sizes)): for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close( if not torch.equal(
all_weights[weight_idx][physical_pos], all_weights[weight_idx][physical_pos],
logical_expert_weights[logical_expert_id][weight_idx], logical_expert_weights[logical_expert_id][weight_idx],
msg=f"Layer {layer}, weight {weight_idx}," ):
f"logical expert {logical_expert_id}: " ok = False
f"Physical expert {physical_pos} has different weights" actual_head = (
f"than expected", all_weights[weight_idx][physical_pos][:8]
) .detach()
.cpu()
.tolist()
)
reference_head = (
logical_expert_weights[logical_expert_id][weight_idx][:8]
.detach()
.cpu()
.tolist()
)
print(
"verify_redundant_experts_have_same_weights failed: "
f"rank={ep_rank}, "
f"layer={layer}, weight_idx={weight_idx}, "
f"logical_expert={logical_expert_id}, "
f"physical_pos={physical_pos}, "
f"actual_head={actual_head}, "
f"reference_head={reference_head}",
flush=True,
)
return ok
def assert_verification_synced(local_ok: bool, msg: str) -> None:
    """Assert that verification passed on every participating rank.

    The local result is encoded as 1 (ok) or 0 (failed) and combined with a
    MIN all-reduce, so the reduced value is 0 as soon as any rank failed.
    The assertion then fails with ``msg``.
    """
    flag = torch.tensor([int(bool(local_ok))], device="cuda", dtype=torch.int32)
    torch.distributed.all_reduce(flag, op=torch.distributed.ReduceOp.MIN)
    all_ranks_ok = flag.item() != 0
    assert all_ranks_ok, msg
def create_eplb_communicator_or_raise(*, group_coordinator, backend, expert_weights):
    """Create an EPLB communicator, wrapping any failure in a RuntimeError.

    All arguments are keyword-only and are forwarded unchanged to
    ``create_eplb_communicator``. If that call raises an ``Exception``, a
    ``RuntimeError`` naming the requested backend is raised instead, chained
    to the original error.
    """
    factory_kwargs = {
        "group_coordinator": group_coordinator,
        "backend": backend,
        "expert_weights": expert_weights,
    }
    try:
        communicator = create_eplb_communicator(**factory_kwargs)
    except Exception as exc:
        raise RuntimeError(
            f"Failed to create EPLB communicator for backend={backend}: {exc}"
        ) from exc
    return communicator
def _test_async_transfer_layer_without_mtp_worker( def _test_async_transfer_layer_without_mtp_worker(
...@@ -243,6 +294,7 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -243,6 +294,7 @@ def _test_async_transfer_layer_without_mtp_worker(
num_layers: int, num_layers: int,
num_local_experts: int, num_local_experts: int,
num_logical_experts: int, num_logical_experts: int,
eplb_communicator: str,
) -> None: ) -> None:
set_env_vars_and_device(env) set_env_vars_and_device(env)
...@@ -254,8 +306,8 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -254,8 +306,8 @@ def _test_async_transfer_layer_without_mtp_worker(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
) )
tp_group = get_tp_group() ep_group_coordinator = get_tp_group()
ep_group = tp_group.device_group ep_group = ep_group_coordinator.device_group
ep_rank = torch.distributed.get_rank() ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}") device = torch.device(f"cuda:{ep_rank}")
...@@ -298,6 +350,13 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -298,6 +350,13 @@ def _test_async_transfer_layer_without_mtp_worker(
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]] expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
cuda_stream = torch.cuda.Stream(device=device) cuda_stream = torch.cuda.Stream(device=device)
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend=eplb_communicator,
expert_weights=expert_weights[0],
)
communicator.set_stream(cuda_stream)
for layer_idx in range(num_layers): for layer_idx in range(num_layers):
is_unchanged, is_received_locally, recv_metadata = asyncio.run( is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_layer( transfer_layer(
...@@ -306,6 +365,7 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -306,6 +365,7 @@ def _test_async_transfer_layer_without_mtp_worker(
expert_weights=expert_weights[layer_idx], expert_weights=expert_weights[layer_idx],
expert_weights_buffer=expert_buffer, expert_weights_buffer=expert_buffer,
ep_group=ep_group, ep_group=ep_group,
communicator=communicator,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
) )
) )
...@@ -320,24 +380,38 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -320,24 +380,38 @@ def _test_async_transfer_layer_without_mtp_worker(
ep_rank=ep_rank, ep_rank=ep_rank,
) )
verify_expert_weights_after_shuffle( local_ok = verify_expert_weights_after_shuffle(
expert_weights, expert_weights,
new_indices, new_indices,
hidden_sizes, hidden_sizes,
ep_rank, ep_rank,
num_local_experts, num_local_experts,
) )
local_ok = (
verify_redundant_experts_have_same_weights( verify_redundant_experts_have_same_weights(
expert_weights, expert_weights,
new_indices, new_indices,
hidden_sizes, hidden_sizes,
ep_rank,
world_size, world_size,
num_local_experts, num_local_experts,
) )
and local_ok
)
assert_verification_synced(
local_ok,
"Async transfer verification failed on at least one rank. "
"See logs for details.",
)
def _test_rearrange_expert_weights_with_redundancy( def _test_rearrange_expert_weights_with_redundancy(
env, world_size, num_layers, num_local_experts, num_logical_experts env,
world_size,
num_layers,
num_local_experts,
num_logical_experts,
eplb_communicator: str,
) -> None: ) -> None:
# Initialize model parallel (using tensor parallel as an entrypoint # Initialize model parallel (using tensor parallel as an entrypoint
# to expert parallel) # to expert parallel)
...@@ -351,7 +425,8 @@ def _test_rearrange_expert_weights_with_redundancy( ...@@ -351,7 +425,8 @@ def _test_rearrange_expert_weights_with_redundancy(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
) )
ep_group = get_tp_group().cpu_group ep_group_coordinator = get_tp_group()
ep_group = ep_group_coordinator.cpu_group
ep_rank = torch.distributed.get_rank() ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}") device = torch.device(f"cuda:{ep_rank}")
...@@ -387,6 +462,12 @@ def _test_rearrange_expert_weights_with_redundancy( ...@@ -387,6 +462,12 @@ def _test_rearrange_expert_weights_with_redundancy(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
) )
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend=eplb_communicator,
expert_weights=expert_weights[0],
)
# Execute weight rearrangement # Execute weight rearrangement
rearrange_expert_weights_inplace( rearrange_expert_weights_inplace(
old_indices, old_indices,
...@@ -394,24 +475,33 @@ def _test_rearrange_expert_weights_with_redundancy( ...@@ -394,24 +475,33 @@ def _test_rearrange_expert_weights_with_redundancy(
expert_weights, expert_weights,
ep_group, ep_group,
is_profile=False, is_profile=False,
communicator=communicator,
) )
# Verify the rearrangement result # Verify the rearrangement result
verify_expert_weights_after_shuffle( local_ok = verify_expert_weights_after_shuffle(
expert_weights, expert_weights,
new_indices, new_indices,
hidden_sizes, hidden_sizes,
ep_rank, ep_rank,
num_local_experts, num_local_experts,
) )
local_ok = (
verify_redundant_experts_have_same_weights( verify_redundant_experts_have_same_weights(
expert_weights, expert_weights,
new_indices, new_indices,
hidden_sizes, hidden_sizes,
ep_rank,
world_size, world_size,
num_local_experts, num_local_experts,
) )
and local_ok
)
assert_verification_synced(
local_ok,
"Rearrange verification failed on at least one rank. See logs for details.",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -437,8 +527,13 @@ def _test_rearrange_expert_weights_with_redundancy( ...@@ -437,8 +527,13 @@ def _test_rearrange_expert_weights_with_redundancy(
(4, 8, 8, 16), (4, 8, 8, 16),
], ],
) )
@pytest.mark.parametrize("eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl"])
def test_rearrange_expert_weights_with_redundancy( def test_rearrange_expert_weights_with_redundancy(
world_size, num_layers, num_local_experts, num_logical_experts world_size,
num_layers,
num_local_experts,
num_logical_experts,
eplb_communicator,
): ):
"""Test the functionality of rearranging expert weights with redundancy.""" """Test the functionality of rearranging expert weights with redundancy."""
...@@ -450,6 +545,7 @@ def test_rearrange_expert_weights_with_redundancy( ...@@ -450,6 +545,7 @@ def test_rearrange_expert_weights_with_redundancy(
num_layers, num_layers,
num_local_experts, num_local_experts,
num_logical_experts, num_logical_experts,
eplb_communicator,
) )
...@@ -464,7 +560,8 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None: ...@@ -464,7 +560,8 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
) )
ep_group = get_tp_group().cpu_group ep_group_coordinator = get_tp_group()
ep_group = ep_group_coordinator.cpu_group
ep_rank = torch.distributed.get_rank() ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}") device = torch.device(f"cuda:{ep_rank}")
...@@ -494,24 +591,40 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None: ...@@ -494,24 +591,40 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
layer_copy.append(weight.clone()) layer_copy.append(weight.clone())
original_weights.append(layer_copy) original_weights.append(layer_copy)
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend="torch_nccl",
expert_weights=expert_weights[0],
)
# Execute rearrangement (should be no change) # Execute rearrangement (should be no change)
rearrange_expert_weights_inplace( rearrange_expert_weights_inplace(
indices, indices,
indices, # Same indices indices, # Same indices
expert_weights, expert_weights,
ep_group, ep_group,
communicator,
is_profile=False, is_profile=False,
) )
# Verify that the weights have not changed # Verify that the weights have not changed
for layer in range(num_layers): local_ok = True
for weight_idx in range(len(hidden_sizes)): for layer in range(num_layers):
torch.testing.assert_close( for weight_idx in range(len(hidden_sizes)):
expert_weights[layer][weight_idx], if not torch.equal(
original_weights[layer][weight_idx], expert_weights[layer][weight_idx],
msg=f"""Layer {layer}, weight {weight_idx} original_weights[layer][weight_idx],
should remain unchanged""", ):
local_ok = False
print(
"test_rearrange_expert_weights_no_change failed: "
f"layer={layer}, weight_idx={weight_idx}",
flush=True,
) )
assert_verification_synced(
local_ok,
"No-change EPLB verification failed on at least one rank.",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -520,11 +633,13 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None: ...@@ -520,11 +633,13 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
(2, 2, 2, 3), (2, 2, 2, 3),
], ],
) )
@pytest.mark.parametrize("eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl"])
def test_async_transfer_layer_without_mtp( def test_async_transfer_layer_without_mtp(
world_size: int, world_size: int,
num_layers: int, num_layers: int,
num_local_experts: int, num_local_experts: int,
num_logical_experts: int, num_logical_experts: int,
eplb_communicator: str,
): ):
"""Exercise async EPLB transfer path without MTP/spec decode.""" """Exercise async EPLB transfer path without MTP/spec decode."""
...@@ -537,6 +652,7 @@ def test_async_transfer_layer_without_mtp( ...@@ -537,6 +652,7 @@ def test_async_transfer_layer_without_mtp(
num_layers, num_layers,
num_local_experts, num_local_experts,
num_logical_experts, num_logical_experts,
eplb_communicator,
) )
...@@ -549,7 +665,10 @@ def test_rearrange_expert_weights_no_change(world_size): ...@@ -549,7 +665,10 @@ def test_rearrange_expert_weights_no_change(world_size):
if torch.accelerator.device_count() < world_size: if torch.accelerator.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test") pytest.skip(f"Need at least {world_size} GPUs to run the test")
distributed_run(_test_rearrange_expert_weights_no_change, world_size) distributed_run(
_test_rearrange_expert_weights_no_change,
world_size,
)
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None: def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
...@@ -563,7 +682,8 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None: ...@@ -563,7 +682,8 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
) )
ep_group = get_tp_group().cpu_group ep_group_coordinator = get_tp_group()
ep_group = ep_group_coordinator.cpu_group
ep_rank = torch.distributed.get_rank() ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}") device = torch.device(f"cuda:{ep_rank}")
...@@ -600,23 +720,40 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None: ...@@ -600,23 +720,40 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
layer_copy.append(weight.clone()) layer_copy.append(weight.clone())
original_weights.append(layer_copy) original_weights.append(layer_copy)
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend="torch_nccl",
expert_weights=expert_weights[0],
)
# Execute profile mode rearrangement # Execute profile mode rearrangement
rearrange_expert_weights_inplace( rearrange_expert_weights_inplace(
old_indices, old_indices,
new_indices, new_indices,
expert_weights, expert_weights,
ep_group, ep_group,
communicator,
is_profile=True, # Profile mode is_profile=True, # Profile mode
) )
# In profile mode, the weights should remain unchanged # In profile mode, the weights should remain unchanged
for layer in range(num_layers): local_ok = True
for weight_idx in range(len(hidden_sizes)): for layer in range(num_layers):
torch.testing.assert_close( for weight_idx in range(len(hidden_sizes)):
expert_weights[layer][weight_idx], if not torch.equal(
original_weights[layer][weight_idx], expert_weights[layer][weight_idx],
msg="In profile mode, the weights should remain unchanged", original_weights[layer][weight_idx],
):
local_ok = False
print(
"test_rearrange_expert_weights_profile_mode failed: "
f"layer={layer}, weight_idx={weight_idx}",
flush=True,
) )
assert_verification_synced(
local_ok,
"Profile-mode EPLB verification failed on at least one rank.",
)
@pytest.mark.parametrize("world_size", [2, 4]) @pytest.mark.parametrize("world_size", [2, 4])
...@@ -625,4 +762,7 @@ def test_rearrange_expert_weights_profile_mode(world_size): ...@@ -625,4 +762,7 @@ def test_rearrange_expert_weights_profile_mode(world_size):
if torch.accelerator.device_count() < world_size: if torch.accelerator.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test") pytest.skip(f"Need at least {world_size} GPUs to run the test")
distributed_run(_test_rearrange_expert_weights_profile_mode, world_size) distributed_run(
_test_rearrange_expert_weights_profile_mode,
world_size,
)
...@@ -35,6 +35,7 @@ DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] ...@@ -35,6 +35,7 @@ DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"] DataParallelBackend = Literal["ray", "mp"]
EPLBPolicyOption = Literal["default"] EPLBPolicyOption = Literal["default"]
DCPCommBackend = Literal["ag_rs", "a2a"] DCPCommBackend = Literal["ag_rs", "a2a"]
EPLBCommunicatorBackend = Literal["torch_nccl", "torch_gloo", "pynccl"]
All2AllBackend = Literal[ All2AllBackend = Literal[
"naive", "naive",
"pplx", "pplx",
...@@ -83,6 +84,15 @@ class EPLBConfig: ...@@ -83,6 +84,15 @@ class EPLBConfig:
policy: EPLBPolicyOption = "default" policy: EPLBPolicyOption = "default"
"""The policy type for expert parallel load balancing (EPLB).""" """The policy type for expert parallel load balancing (EPLB)."""
communicator: EPLBCommunicatorBackend | None = None
"""
Backend for EPLB expert weight communication:
- "torch_nccl": Use torch.distributed on the device process group
- "torch_gloo": Use torch.distributed gloo with CPU staging
- "pynccl": Use PyNccl send/recv
- None: Auto-select backend ("torch_gloo" for async, "torch_nccl" for sync)
"""
@model_validator(mode="after") @model_validator(mode="after")
def _validate_eplb_config(self) -> Self: def _validate_eplb_config(self) -> Self:
if self.use_async and self.policy != "default": if self.use_async and self.policy != "default":
...@@ -764,16 +774,18 @@ class ParallelConfig: ...@@ -764,16 +774,18 @@ class ParallelConfig:
"backend is mp, uni or external_launcher." "backend is mp, uni or external_launcher."
) )
if ( if self.enable_eplb and self.eplb_config.communicator is None:
self.all2all_backend in ("allgather_reducescatter") if self.enable_elastic_ep:
and self.eplb_config.use_async # Elastic EP requires stateless mode
): # (torch.distributed.batch_isend_irecv doesn't
logger.warning( # support stateless mode), so we use PyNCCL backend
"Async EPLB causes hangs with the '%s' all2all backend. " self.eplb_config.communicator = "pynccl"
"Forcing synchronous EPLB.", elif self.eplb_config.use_async:
self.all2all_backend, # Torch Gloo is a backend that allows avoiding hangs
) # due to NCCL multi-thread conflicts in async EPLB
self.eplb_config.use_async = False self.eplb_config.communicator = "torch_gloo"
else:
self.eplb_config.communicator = "torch_nccl"
@property @property
def use_ray(self) -> bool: def use_ray(self) -> bool:
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
# variable in the code. # variable in the code.
import ctypes import ctypes
import functools
import platform import platform
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
...@@ -75,26 +76,34 @@ class ncclDataTypeEnum: ...@@ -75,26 +76,34 @@ class ncclDataTypeEnum:
ncclFloat8e4m3 = 10 ncclFloat8e4m3 = 10
ncclNumTypes = 11 ncclNumTypes = 11
@classmethod
@functools.lru_cache(maxsize=1)
def _torch_to_nccl_map(cls) -> dict[torch.dtype, int]:
    """Return the table mapping torch dtypes to NCCL dtype codes.

    The result is memoized with ``functools.lru_cache``.
    """
    dtype_pairs = [
        (torch.int8, cls.ncclInt8),
        (torch.uint8, cls.ncclUint8),
        (torch.int32, cls.ncclInt32),
        (torch.int64, cls.ncclInt64),
        (torch.float16, cls.ncclFloat16),
        (torch.float32, cls.ncclFloat32),
        (torch.float64, cls.ncclFloat64),
        (torch.bfloat16, cls.ncclBfloat16),
        # The fp8 dtype is obtained from the current platform, not hard-coded.
        (current_platform.fp8_dtype(), cls.ncclFloat8e4m3),
    ]
    return dict(dtype_pairs)
@classmethod
def supports_torch_dtype(cls, dtype: torch.dtype) -> bool:
    """Return True when ``dtype`` is a key of the torch-to-NCCL dtype table."""
    known_dtypes = cls._torch_to_nccl_map()
    return dtype in known_dtypes
@classmethod
def try_from_torch(cls, dtype: torch.dtype) -> int | None:
return cls._torch_to_nccl_map().get(dtype)
@classmethod @classmethod
def from_torch(cls, dtype: torch.dtype) -> int: def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8: nccl_dtype = cls.try_from_torch(dtype)
return cls.ncclInt8 if nccl_dtype is not None:
if dtype == torch.uint8: return nccl_dtype
return cls.ncclUint8
if dtype == torch.int32:
return cls.ncclInt32
if dtype == torch.int64:
return cls.ncclInt64
if dtype == torch.float16:
return cls.ncclFloat16
if dtype == torch.float32:
return cls.ncclFloat32
if dtype == torch.float64:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
if dtype == current_platform.fp8_dtype():
return cls.ncclFloat8e4m3
raise ValueError( raise ValueError(
f"Unsupported dtype {dtype}: should be one of " f"Unsupported dtype {dtype}: should be one of "
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16," f"int8, uint8, int32, int64, float16, float32, float64, bfloat16,"
......
...@@ -29,8 +29,10 @@ from vllm.distributed.elastic_ep.standby_state import ( ...@@ -29,8 +29,10 @@ from vllm.distributed.elastic_ep.standby_state import (
get_standby_ep_group, get_standby_ep_group,
pop_standby_groups, pop_standby_groups,
) )
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
_replace_active_groups, _replace_active_groups,
get_eplb_group,
prepare_communication_buffer_for_model, prepare_communication_buffer_for_model,
) )
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
...@@ -411,6 +413,13 @@ class ElasticEPScalingExecutor: ...@@ -411,6 +413,13 @@ class ElasticEPScalingExecutor:
module.quant_method = module.quant_method.old_quant_method module.quant_method = module.quant_method.old_quant_method
module.runner = module._init_runner() module.runner = module._init_runner()
prepare_communication_buffer_for_model(self.worker.model_runner.model) prepare_communication_buffer_for_model(self.worker.model_runner.model)
eplb_model_state.communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend=parallel_config.eplb_config.communicator,
expert_weights=model.expert_weights[0],
)
if ( if (
self.worker.vllm_config.compilation_config.mode self.worker.vllm_config.compilation_config.mode
== CompilationMode.STOCK_TORCH_COMPILE == CompilationMode.STOCK_TORCH_COMPILE
......
...@@ -98,6 +98,8 @@ async def transfer_run_periodically( ...@@ -98,6 +98,8 @@ async def transfer_run_periodically(
assert state.is_async assert state.is_async
for model_state in state.model_states.values(): for model_state in state.model_states.values():
# Set the async worker's CUDA stream on the communicator
model_state.communicator.set_stream(cuda_stream)
rebalancing_algorithm_executed = False rebalancing_algorithm_executed = False
physical_to_logical_map_cpu = None physical_to_logical_map_cpu = None
current_num_layers = model_state.model.num_moe_layers current_num_layers = model_state.model.num_moe_layers
...@@ -157,6 +159,7 @@ async def transfer_run_periodically( ...@@ -157,6 +159,7 @@ async def transfer_run_periodically(
expert_weights=model_state.model.expert_weights[layer_idx], expert_weights=model_state.model.expert_weights[layer_idx],
expert_weights_buffer=model_state.expert_buffer, expert_weights_buffer=model_state.expert_buffer,
ep_group=eplb_group, ep_group=eplb_group,
communicator=model_state.communicator,
is_profile=is_profile, is_profile=is_profile,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
) )
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
EPLB communicator implementations and factory.
"""
from abc import ABC, abstractmethod
from collections.abc import Sequence
import torch
from torch.distributed import (
P2POp,
ProcessGroup,
batch_isend_irecv,
)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import (
ncclDataTypeEnum,
)
from vllm.distributed.parallel_state import GroupCoordinator, is_local_first_rank
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
logger = init_logger(__name__)
class EplbCommunicator(ABC):
"""Abstract EPLB communicator for expert weight transfers."""
@abstractmethod
def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
pass
@abstractmethod
def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
pass
@abstractmethod
def execute(self) -> None:
pass
def set_stream(self, cuda_stream: torch.cuda.Stream | None) -> None:
self._cuda_stream = cuda_stream
def _log_initialized(self) -> None:
if is_local_first_rank():
logger.info("Initialized EPLB communicator: %s.", self.__class__.__name__)
class TorchDistNcclEplbCommunicator(EplbCommunicator):
    """EPLB communicator that batches torch.distributed isend/irecv calls."""

    def __init__(
        self,
        ep_group: ProcessGroup,
        cuda_stream: torch.cuda.Stream | None = None,
    ) -> None:
        self._ep_group = ep_group
        self._cuda_stream = cuda_stream
        # Operations queued by add_send/add_recv until the next execute().
        self._p2p_ops: list[P2POp] = []
        self._log_initialized()

    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        """Queue an isend of ``tensor`` to ``dst_rank`` on the EP group."""
        send_op = P2POp(torch.distributed.isend, tensor, dst_rank, self._ep_group)
        self._p2p_ops.append(send_op)

    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        """Queue an irecv into ``tensor`` from ``src_rank`` on the EP group."""
        recv_op = P2POp(torch.distributed.irecv, tensor, src_rank, self._ep_group)
        self._p2p_ops.append(recv_op)

    def execute(self) -> None:
        """Issue all queued ops as one batch, wait for them, and reset the queue.

        Does nothing when no operation is queued. The queue is emptied even
        if issuing or waiting raises.
        """
        pending = self._p2p_ops
        if pending:
            try:
                # NOTE(review): the waits are kept inside the stream context;
                # confirm this nesting against the upstream file, since this
                # view of the code has lost its indentation.
                with torch.cuda.stream(self._cuda_stream):
                    for work in batch_isend_irecv(pending):
                        work.wait()
            finally:
                pending.clear()
class TorchDistGlooStagedEplbCommunicator(EplbCommunicator):
    """EPLB communicator using gloo P2P with CPU staging.

    Tensors to send are first copied to CPU; tensors to receive get a CPU
    staging buffer that is copied back into the destination tensor after the
    exchange completes.

    ``dst_rank`` / ``src_rank`` are ranks *within* ``cpu_group``. They are
    translated to global ranks before building each ``P2POp`` because the
    positional ``peer`` argument of ``P2POp`` is a global rank.
    """

    def __init__(
        self,
        cpu_group: ProcessGroup,
        cuda_stream: torch.cuda.Stream | None = None,
    ) -> None:
        """Store the CPU process group and optional stream; start with no ops.

        Args:
            cpu_group: Process group the staged P2P operations are issued on.
            cuda_stream: Stream used for the staging copies; ``None`` uses
                the current stream.
        """
        self._cpu_group = cpu_group
        self._cuda_stream = cuda_stream
        # Each entry is (kind, tensor, peer rank in cpu_group), where kind is
        # "send" or "recv".
        self._ops: list[tuple[str, torch.Tensor, int]] = []
        self._log_initialized()

    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        """Queue ``tensor`` to be sent to group rank ``dst_rank``."""
        self._ops.append(("send", tensor, dst_rank))

    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        """Queue a receive into ``tensor`` from group rank ``src_rank``."""
        self._ops.append(("recv", tensor, src_rank))

    def execute(self) -> None:
        """Stage queued tensors on CPU, exchange them, and copy receives back.

        The queue is emptied once staging has been attempted, even if staging
        raises. The final copies into the destination tensors are issued with
        ``non_blocking=True`` and are not synchronized here.
        """
        if not self._ops:
            return

        p2p_ops: list[P2POp] = []
        # (destination tensor, CPU staging buffer) pairs for every receive.
        recv_staging: list[tuple[torch.Tensor, torch.Tensor]] = []

        try:
            with torch.cuda.stream(self._cuda_stream):
                for kind, tensor, peer_rank in self._ops:
                    # P2POp expects a global rank as its positional peer.
                    peer = torch.distributed.get_global_rank(
                        self._cpu_group, peer_rank
                    )
                    if kind == "send":
                        staged = tensor.to(device="cpu", non_blocking=True)
                        p2p_ops.append(
                            P2POp(
                                torch.distributed.isend,
                                staged,
                                peer,
                                self._cpu_group,
                            )
                        )
                        continue
                    staged = torch.empty_like(tensor, device="cpu")
                    p2p_ops.append(
                        P2POp(
                            torch.distributed.irecv,
                            staged,
                            peer,
                            self._cpu_group,
                        )
                    )
                    recv_staging.append((tensor, staged))
        finally:
            self._ops.clear()

        # Wait for all D2H copies to finish
        # before issuing gloo batch_isend_irecv operations.
        if self._cuda_stream is not None:
            self._cuda_stream.synchronize()
        else:
            torch.cuda.current_stream().synchronize()

        reqs = batch_isend_irecv(p2p_ops)
        for req in reqs:
            req.wait()

        if not recv_staging:
            return

        # Copy the received data from the CPU staging buffers into the
        # destination tensors.
        with torch.cuda.stream(self._cuda_stream):
            for dst_tensor, cpu_tensor in recv_staging:
                dst_tensor.copy_(cpu_tensor, non_blocking=True)
class PyNcclEplbCommunicator(EplbCommunicator):
    """EPLB communicator backed by PyNcclCommunicator using ncclSend/ncclRecv.

    The first ``add_send`` / ``add_recv`` after construction or after an
    ``execute`` calls ``group_start()`` on the wrapped communicator;
    ``execute`` calls the matching ``group_end()``.
    """

    def __init__(
        self,
        pynccl_comm: PyNcclCommunicator,
        cuda_stream: torch.cuda.Stream | None = None,
    ) -> None:
        """Wrap ``pynccl_comm``; no group is open initially.

        Args:
            pynccl_comm: Communicator whose ``send`` / ``recv`` /
                ``group_start`` / ``group_end`` are used.
            cuda_stream: Stream passed to every ``send`` / ``recv`` call.
        """
        self._pynccl_comm = pynccl_comm
        self._cuda_stream = cuda_stream
        # True between the group_start() call and the matching group_end().
        self._group_started = False
        self._log_initialized()

    def _ensure_group_started(self) -> None:
        """Call ``group_start()`` unless a group is already open."""
        if not self._group_started:
            self._pynccl_comm.group_start()
            self._group_started = True

    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        """Issue a send of ``tensor`` to ``dst_rank`` inside the open group."""
        self._ensure_group_started()
        self._pynccl_comm.send(tensor, dst_rank, stream=self._cuda_stream)

    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        """Issue a receive into ``tensor`` from ``src_rank`` inside the group."""
        self._ensure_group_started()
        self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream)

    def execute(self) -> None:
        """Close the open group, if any, via ``group_end()``.

        The open-group flag is reset even when ``group_end()`` raises, so a
        later ``add_send`` / ``add_recv`` calls ``group_start()`` again
        instead of assuming a group is still open.
        """
        if not self._group_started:
            return
        try:
            self._pynccl_comm.group_end()
        finally:
            self._group_started = False
def create_eplb_communicator(
    group_coordinator: GroupCoordinator,
    backend: str | None,
    expert_weights: Sequence[torch.Tensor],
) -> EplbCommunicator:
    """Build the EPLB communicator selected by ``backend``.

    Args:
        group_coordinator: Supplies ``cpu_group``, ``device_group`` and
            ``device_communicator``.
        backend: ``"torch_nccl"``, ``"torch_gloo"`` or ``"pynccl"``; ``None``
            is treated as ``"torch_nccl"``.
        expert_weights: Expert weight tensors. The device of the first one
            selects the torch process group, and all dtypes are checked when
            the pynccl backend is built.

    Returns:
        The communicator for the resolved backend. With a
        ``StatelessGroupCoordinator`` the result is always the pynccl one.

    Raises:
        RuntimeError: pynccl was selected but the tensors are on CPU, contain
            unsupported dtypes, or no usable pynccl communicator exists.
        ValueError: the backend is unknown, or is neither ``"torch_nccl"``
            nor ``"pynccl"`` with a stateless group coordinator.
    """
    # Keep a safe default for callers that have not resolved communicator yet.
    backend = "torch_nccl" if backend is None else backend

    device_type = "cpu"
    if expert_weights:
        device_type = expert_weights[0].device.type

    if device_type == "cpu":
        torch_group = group_coordinator.cpu_group
    else:
        torch_group = group_coordinator.device_group

    def _build_pynccl() -> EplbCommunicator:
        """Validate the pynccl prerequisites and wrap the communicator."""
        if device_type == "cpu":
            raise RuntimeError(
                "EPLB communicator 'pynccl' supports only cuda-like devices "
                f"(got {device_type})."
            )

        rejected = {
            weight.dtype
            for weight in expert_weights
            if not ncclDataTypeEnum.supports_torch_dtype(weight.dtype)
        }
        if rejected:
            # Sort by name so the message is deterministic.
            listing = ", ".join(str(dtype) for dtype in sorted(rejected, key=str))
            raise RuntimeError(
                "EPLB communicator 'pynccl' requested but expert weights contain "
                "unsupported dtypes: "
                f"({listing})."
            )

        device_comm = group_coordinator.device_communicator
        pynccl_comm = None
        if device_comm is not None:
            pynccl_comm = getattr(device_comm, "pynccl_comm", None)

        usable = (
            pynccl_comm is not None
            and not pynccl_comm.disabled
            and pynccl_comm.available
        )
        if not usable:
            raise RuntimeError("EPLB communicator 'pynccl' requested but unavailable.")

        try:
            return PyNcclEplbCommunicator(pynccl_comm=pynccl_comm)
        except Exception as exc:
            raise RuntimeError(
                f"Failed to initialize PyNcclEplbCommunicator ({exc})."
            ) from exc

    if isinstance(group_coordinator, StatelessGroupCoordinator):
        if backend == "torch_nccl":
            logger.warning(
                "Stateless elastic EP requires PyNCCL backend. "
                "Forcing EPLB communicator to 'pynccl'."
            )
        elif backend != "pynccl":
            raise ValueError(
                f"Elastic EP requires 'torch_nccl' or 'pynccl' EPLB communicator "
                f"(got '{backend}'). torch_gloo is not supported with stateless groups."
            )
        return _build_pynccl()

    if backend == "torch_gloo":
        return TorchDistGlooStagedEplbCommunicator(
            cpu_group=group_coordinator.cpu_group,
        )
    if backend == "torch_nccl":
        return TorchDistNcclEplbCommunicator(ep_group=torch_group)
    if backend == "pynccl":
        return _build_pynccl()
    raise ValueError(f"Unknown EPLB communicator backend: {backend}")
...@@ -37,6 +37,7 @@ from torch.distributed import ProcessGroup, all_reduce ...@@ -37,6 +37,7 @@ from torch.distributed import ProcessGroup, all_reduce
from vllm.config import ModelConfig, ParallelConfig from vllm.config import ModelConfig, ParallelConfig
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_ep_group, get_ep_group,
get_eplb_group,
get_node_count, get_node_count,
in_the_same_node_as, in_the_same_node_as,
) )
...@@ -46,6 +47,7 @@ from vllm.logger import init_logger ...@@ -46,6 +47,7 @@ from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts from vllm.model_executor.models.interfaces import MixtureOfExperts
from .async_worker import start_async_worker from .async_worker import start_async_worker
from .eplb_communicator import EplbCommunicator, create_eplb_communicator
from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
from .rebalance_execute import ( from .rebalance_execute import (
RecvMetadata, RecvMetadata,
...@@ -225,6 +227,10 @@ class EplbModelState: ...@@ -225,6 +227,10 @@ class EplbModelState:
""" """
CUDA device index for the async EPLB worker thread. CUDA device index for the async EPLB worker thread.
""" """
communicator: EplbCommunicator
"""
The communicator for expert weight transfers.
"""
new_physical_to_logical_map: torch.Tensor | None = None new_physical_to_logical_map: torch.Tensor | None = None
""" """
intermediate variable between `move_to_buffer` and `move_to_workspace`. intermediate variable between `move_to_buffer` and `move_to_workspace`.
...@@ -472,6 +478,12 @@ class EplbState: ...@@ -472,6 +478,12 @@ class EplbState:
self._init_should_record_tensor(model) self._init_should_record_tensor(model)
expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]] expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend=self.parallel_config.eplb_config.communicator,
expert_weights=model.expert_weights[0],
)
model_state = EplbModelState( model_state = EplbModelState(
physical_to_logical_map=physical_to_logical_map, physical_to_logical_map=physical_to_logical_map,
logical_to_physical_map=logical_to_physical_map, logical_to_physical_map=logical_to_physical_map,
...@@ -498,6 +510,7 @@ class EplbState: ...@@ -498,6 +510,7 @@ class EplbState:
recv_dst_rows=np.array([]), recv_dst_rows=np.array([]),
), ),
cuda_device_index=self.cuda_device_index, cuda_device_index=self.cuda_device_index,
communicator=communicator,
new_physical_to_logical_map=None, new_physical_to_logical_map=None,
) )
self.model_states[model_config.compute_hash()] = model_state self.model_states[model_config.compute_hash()] = model_state
...@@ -800,6 +813,7 @@ class EplbState: ...@@ -800,6 +813,7 @@ class EplbState:
new_physical_to_logical_map, new_physical_to_logical_map,
eplb_model_state.model.expert_weights, eplb_model_state.model.expert_weights,
ep_group, ep_group,
eplb_model_state.communicator,
is_profile, is_profile,
rank_mapping, rank_mapping,
) )
...@@ -923,11 +937,8 @@ class EplbState: ...@@ -923,11 +937,8 @@ class EplbState:
new_indices=new_indices, new_indices=new_indices,
ep_rank=ep_group.rank(), ep_rank=ep_group.rank(),
) )
# Record event after consuming buffer to signal async thread
# that it's safe to overwrite the intermediate buffer transferred_layer = model_state.layer_to_transfer
consumed_event = torch.cuda.Event()
consumed_event.record()
model_state.buffer_consumed_event = consumed_event
transferred_layer = model_state.layer_to_transfer transferred_layer = model_state.layer_to_transfer
assert model_state.new_physical_to_logical_map is not None assert model_state.new_physical_to_logical_map is not None
...@@ -936,6 +947,13 @@ class EplbState: ...@@ -936,6 +947,13 @@ class EplbState:
new_physical_to_logical_map=model_state.new_physical_to_logical_map, new_physical_to_logical_map=model_state.new_physical_to_logical_map,
layer=transferred_layer, layer=transferred_layer,
) )
# Record event after consuming buffer to signal async thread
# that it's safe to overwrite the intermediate buffer
consumed_event = torch.cuda.Event()
consumed_event.record()
model_state.buffer_consumed_event = consumed_event
# After the main thread consumes, advance layer_to_transfer # After the main thread consumes, advance layer_to_transfer
model_state.layer_to_transfer += 1 model_state.layer_to_transfer += 1
model_state.ep_buffer_ready = 0 model_state.ep_buffer_ready = 0
......
...@@ -21,6 +21,10 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None: ...@@ -21,6 +21,10 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
is_eplb_enabled = parallel_config.enable_eplb is_eplb_enabled = parallel_config.enable_eplb
async_eplb = parallel_config.eplb_config.use_async async_eplb = parallel_config.eplb_config.use_async
is_deepep_ll = parallel_config.all2all_backend == "deepep_low_latency" is_deepep_ll = parallel_config.all2all_backend == "deepep_low_latency"
is_nccl_based_eplb_communicator = parallel_config.eplb_config.communicator in (
"torch_nccl",
"pynccl",
)
# Override NCCL_MAX_CTAS to avoid hangs when using async EPLB with the # Override NCCL_MAX_CTAS to avoid hangs when using async EPLB with the
# DeepEP low-latency backend. # DeepEP low-latency backend.
...@@ -39,7 +43,13 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None: ...@@ -39,7 +43,13 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
# Limiting NCCL occupancy via NCCL_MAX_CTAS leaves space for the DeepEP # Limiting NCCL occupancy via NCCL_MAX_CTAS leaves space for the DeepEP
# cooperative kernel to launch and complete, breaking the deadlock. # cooperative kernel to launch and complete, breaking the deadlock.
# See: https://github.com/deepseek-ai/DeepEP/issues/496 # See: https://github.com/deepseek-ai/DeepEP/issues/496
if is_data_parallel and is_eplb_enabled and is_deepep_ll and async_eplb: if (
is_data_parallel
and is_eplb_enabled
and is_deepep_ll
and async_eplb
and is_nccl_based_eplb_communicator
):
current_value_str = os.getenv("NCCL_MAX_CTAS") current_value_str = os.getenv("NCCL_MAX_CTAS")
if current_value_str and current_value_str.isdigit(): if current_value_str and current_value_str.isdigit():
...@@ -49,6 +59,7 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None: ...@@ -49,6 +59,7 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
os.environ["NCCL_MAX_CTAS"] = str(override_value) os.environ["NCCL_MAX_CTAS"] = str(override_value)
logger.info_once( logger.info_once(
f"EPLB: Setting NCCL_MAX_CTAS={override_value} " f"EPLB: Setting NCCL_MAX_CTAS={override_value} "
"for expert parallel with EPLB and deepep_low_latency backend", "for expert parallel with NCCL-based EPLB communicator and "
"deepep_low_latency backend",
scope="global", scope="global",
) )
...@@ -11,19 +11,9 @@ from dataclasses import dataclass ...@@ -11,19 +11,9 @@ from dataclasses import dataclass
import numpy as np import numpy as np
import torch import torch
from torch.distributed import ( from torch.distributed import ProcessGroup, all_gather
P2POp,
ProcessGroup,
all_gather,
batch_isend_irecv,
get_global_rank,
)
from vllm.distributed.parallel_state import get_ep_group from .eplb_communicator import EplbCommunicator
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass @dataclass
...@@ -158,7 +148,8 @@ def move_to_buffer( ...@@ -158,7 +148,8 @@ def move_to_buffer(
expert_weights: Sequence[torch.Tensor], expert_weights: Sequence[torch.Tensor],
expert_weights_buffers: Sequence[torch.Tensor], expert_weights_buffers: Sequence[torch.Tensor],
cuda_stream: torch.cuda.Stream | None, cuda_stream: torch.cuda.Stream | None,
ep_group: ProcessGroup, ep_rank: int,
communicator: EplbCommunicator,
) -> MoveToBufferResult: ) -> MoveToBufferResult:
""" """
Rearranges expert weights during EPLB rebalancing. Rearranges expert weights during EPLB rebalancing.
...@@ -172,7 +163,8 @@ def move_to_buffer( ...@@ -172,7 +163,8 @@ def move_to_buffer(
expert_weights: Original expert weights for the layer. expert_weights: Original expert weights for the layer.
expert_weights_buffers: Intermediate buffers (one per tensor). expert_weights_buffers: Intermediate buffers (one per tensor).
cuda_stream: CUDA stream for async copies (can be None for sync mode). cuda_stream: CUDA stream for async copies (can be None for sync mode).
ep_group: Distributed process group for expert parallel comms. ep_rank: Rank of this process in expert parallel group.
communicator: EplbCommunicator instance for P2P communication.
Returns: Returns:
is_unchanged (np.ndarray): (num_local_experts,), True where an expert row is_unchanged (np.ndarray): (num_local_experts,), True where an expert row
...@@ -182,8 +174,6 @@ def move_to_buffer( ...@@ -182,8 +174,6 @@ def move_to_buffer(
RecvMetadata: Metadata needed for completing remote weight transfers. RecvMetadata: Metadata needed for completing remote weight transfers.
""" """
assert old_indices.shape == new_indices.shape assert old_indices.shape == new_indices.shape
ep_rank = ep_group.rank()
recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_) recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64) send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32) send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32)
...@@ -247,22 +237,9 @@ def move_to_buffer( ...@@ -247,22 +237,9 @@ def move_to_buffer(
expert = new_local_expert_ids[dst] expert = new_local_expert_ids[dst]
src_local = expert_to_src_map.get(expert, -1) src_local = expert_to_src_map.get(expert, -1)
if src_local != -1: if src_local != -1:
for w, b in zip(expert_weights, expert_weights_buffers): with torch.cuda.stream(cuda_stream):
b[dst].copy_(w[src_local], non_blocking=True) for w, b in zip(expert_weights, expert_weights_buffers):
b[dst].copy_(w[src_local], non_blocking=True)
p2p_ops: list[P2POp] = []
if isinstance(get_ep_group(), StatelessGroupCoordinator):
ep_group = get_ep_group()
is_stateless = True
else:
is_stateless = False
# Pre-compute global ranks mapping (only needed for non-stateless groups)
ep_size = ep_group.size()
if not is_stateless:
rank_to_global = {
rank: get_global_rank(ep_group, rank) for rank in range(ep_size)
}
# 2. Post sends # 2. Post sends
if send_count > 0: if send_count > 0:
...@@ -294,23 +271,8 @@ def move_to_buffer( ...@@ -294,23 +271,8 @@ def move_to_buffer(
if recver_pos < len(ranks_to_recv): if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos]) recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks: for dst in recv_ranks:
if is_stateless: for w in expert_weights:
for w in expert_weights: communicator.add_send(w[src], dst)
op = object.__new__(P2POp)
op.op = torch.distributed.isend
op.tensor = w[src]
op.group_peer = dst
p2p_ops.append(op)
else:
dst_global = rank_to_global[dst]
p2p_ops += [
P2POp(
torch.distributed.isend,
w[src],
dst_global,
)
for w in expert_weights
]
# 3. Post recvs # 3. Post recvs
if recv_count > 0: if recv_count > 0:
...@@ -339,40 +301,11 @@ def move_to_buffer( ...@@ -339,40 +301,11 @@ def move_to_buffer(
src = ranks_to_send[recver_pos // num_dst_per_sender] src = ranks_to_send[recver_pos // num_dst_per_sender]
else: else:
src = ranks_to_send[recver_pos - remainder_start] src = ranks_to_send[recver_pos - remainder_start]
if is_stateless: for b in expert_weights_buffers:
for b in expert_weights_buffers: communicator.add_recv(b[dst], src)
op = object.__new__(P2POp)
op.op = torch.distributed.irecv
op.tensor = b[dst]
op.group_peer = src
p2p_ops.append(op)
else:
src_global = rank_to_global[src]
p2p_ops += [
P2POp(
torch.distributed.irecv,
b[dst],
src_global,
)
for b in expert_weights_buffers
]
# 4. Execute the P2P operations. The real communication happens here. # 4. Execute the P2P operations. The real communication happens here.
if p2p_ops and cuda_stream is not None: communicator.execute()
with torch.cuda.stream(cuda_stream):
if is_stateless:
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
else:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
elif p2p_ops:
if is_stateless:
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
else:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
# wait for the communication to finish # wait for the communication to finish
return ( return (
is_unchanged, is_unchanged,
...@@ -471,6 +404,7 @@ async def transfer_layer( ...@@ -471,6 +404,7 @@ async def transfer_layer(
expert_weights: Sequence[torch.Tensor], expert_weights: Sequence[torch.Tensor],
expert_weights_buffer: Sequence[torch.Tensor], expert_weights_buffer: Sequence[torch.Tensor],
ep_group: ProcessGroup, ep_group: ProcessGroup,
communicator: EplbCommunicator,
is_profile: bool = False, is_profile: bool = False,
cuda_stream: torch.cuda.Stream | None = None, cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
...@@ -489,6 +423,7 @@ async def transfer_layer( ...@@ -489,6 +423,7 @@ async def transfer_layer(
For example, a linear layer may have up and down projection. For example, a linear layer may have up and down projection.
expert_weights_buffer: Intermediate buffers (one per weight tensor). expert_weights_buffer: Intermediate buffers (one per weight tensor).
ep_group: The device process group for expert parallelism. ep_group: The device process group for expert parallelism.
communicator: EplbCommunicator instance for P2P communication.
is_profile (bool): If `True`, do not perform any actual weight copy. is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers. communications to reserve enough memory for the buffers.
...@@ -542,7 +477,8 @@ async def transfer_layer( ...@@ -542,7 +477,8 @@ async def transfer_layer(
expert_weights=expert_weights, expert_weights=expert_weights,
expert_weights_buffers=expert_weights_buffer, expert_weights_buffers=expert_weights_buffer,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
ep_group=ep_group, ep_rank=ep_group.rank(),
communicator=communicator,
) )
return is_unchanged, is_received_locally, recv_metadata return is_unchanged, is_received_locally, recv_metadata
...@@ -552,6 +488,7 @@ def rearrange_expert_weights_inplace( ...@@ -552,6 +488,7 @@ def rearrange_expert_weights_inplace(
new_global_expert_indices: torch.Tensor, new_global_expert_indices: torch.Tensor,
expert_weights: Sequence[Sequence[torch.Tensor]], expert_weights: Sequence[Sequence[torch.Tensor]],
ep_group: ProcessGroup, ep_group: ProcessGroup,
communicator: EplbCommunicator,
is_profile: bool = False, is_profile: bool = False,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
) -> None: ) -> None:
...@@ -569,6 +506,7 @@ def rearrange_expert_weights_inplace( ...@@ -569,6 +506,7 @@ def rearrange_expert_weights_inplace(
For example, a linear layer may have up and down projection, For example, a linear layer may have up and down projection,
so weight_count = 2. Each weight's hidden size can be different. so weight_count = 2. Each weight's hidden size can be different.
ep_group: The device process group for expert parallelism. ep_group: The device process group for expert parallelism.
communicator: EplbCommunicator instance for P2P communication.
is_profile (bool): If `True`, do not perform any actual weight copy. is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers. communications to reserve enough memory for the buffers.
...@@ -599,6 +537,7 @@ def rearrange_expert_weights_inplace( ...@@ -599,6 +537,7 @@ def rearrange_expert_weights_inplace(
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
ep_size = ep_group.size() ep_size = ep_group.size()
ep_rank = ep_group.rank()
assert num_physical_experts == ep_size * num_local_physical_experts assert num_physical_experts == ep_size * num_local_physical_experts
first_layer_weights = list(expert_weights[0]) first_layer_weights = list(expert_weights[0])
...@@ -635,7 +574,8 @@ def rearrange_expert_weights_inplace( ...@@ -635,7 +574,8 @@ def rearrange_expert_weights_inplace(
expert_weights=expert_weights[layer_idx], expert_weights=expert_weights[layer_idx],
expert_weights_buffers=weights_buffer, expert_weights_buffers=weights_buffer,
cuda_stream=None, cuda_stream=None,
ep_group=ep_group, ep_rank=ep_rank,
communicator=communicator,
) )
move_from_buffer( move_from_buffer(
...@@ -645,7 +585,7 @@ def rearrange_expert_weights_inplace( ...@@ -645,7 +585,7 @@ def rearrange_expert_weights_inplace(
is_received_locally=is_received_locally, is_received_locally=is_received_locally,
recv_metadata=recv_metadata, recv_metadata=recv_metadata,
new_indices=new_global_expert_indices_cpu[layer_idx], new_indices=new_global_expert_indices_cpu[layer_idx],
ep_rank=ep_group.rank(), ep_rank=ep_rank,
) )
......
...@@ -1690,11 +1690,7 @@ def initialize_model_parallel( ...@@ -1690,11 +1690,7 @@ def initialize_model_parallel(
# using torch.distributed in execution with torch.distributed in EPLB. # using torch.distributed in execution with torch.distributed in EPLB.
global _EPLB global _EPLB
assert _EPLB is None, "EPLB group is already initialized" assert _EPLB is None, "EPLB group is already initialized"
if ( if config.parallel_config.enable_eplb:
config is not None
and config.parallel_config is not None
and config.parallel_config.enable_eplb
):
if enable_elastic_ep: if enable_elastic_ep:
_EPLB = _init_stateless_group( _EPLB = _init_stateless_group(
group_ranks, group_ranks,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment