"examples/vscode:/vscode.git/clone" did not exist on "75f81750f3a9071b23f7f2d5c9f9f1c2cd0091b1"
Unverified Commit 8bdd8b5c authored by Ilya Markov's avatar Ilya Markov Committed by GitHub
Browse files

Enable symmetric memory all reduce by default only enabling for TP (#25070)


Signed-off-by: default avatarilmarkov <markovilya197@gmail.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent a8ffc4f0
...@@ -164,6 +164,7 @@ steps: ...@@ -164,6 +164,7 @@ steps:
- tests/v1/test_internal_lb_dp.py - tests/v1/test_internal_lb_dp.py
- tests/v1/test_hybrid_lb_dp.py - tests/v1/test_hybrid_lb_dp.py
- tests/v1/engine/test_engine_core_client.py - tests/v1/engine/test_engine_core_client.py
- tests/distributed/test_symm_mem_allreduce.py
commands: commands:
# test with torchrun tp=2 and external_dp=2 # test with torchrun tp=2 and external_dp=2
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
...@@ -188,6 +189,7 @@ steps: ...@@ -188,6 +189,7 @@ steps:
- pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py - pytest -v -s distributed/test_pynccl.py
- pytest -v -s distributed/test_events.py - pytest -v -s distributed/test_events.py
- pytest -v -s distributed/test_symm_mem_allreduce.py
# TODO: create a dedicated test section for multi-GPU example tests # TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests # when we have multiple distributed example tests
- pushd ../examples/offline_inference - pushd ../examples/offline_inference
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import queue
import random import random
import typing import typing
...@@ -10,26 +11,31 @@ import torch.distributed as dist ...@@ -10,26 +11,31 @@ import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import ( from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator) CudaCommunicator)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, from vllm.distributed.parallel_state import (get_tp_group,
get_tp_group,
init_distributed_environment, init_distributed_environment,
initialize_model_parallel) initialize_model_parallel)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import update_environment_variables from vllm.utils import update_environment_variables
torch.manual_seed(42) torch.manual_seed(42)
random.seed(44) random.seed(44)
test_size_elements = 4 * 1024 * 1024 test_size_elements = 1024 * 1024
def symm_mem_allreduce_worker(local_rank: int, world_size: int): def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
monkeypatch = pytest.MonkeyPatch() monkeypatch = pytest.MonkeyPatch()
with monkeypatch.context() as m: config = VllmConfig(parallel_config=ParallelConfig(
tensor_parallel_size=world_size))
with monkeypatch.context() as m, set_current_vllm_config(config):
m.delenv("CUDA_VISIBLE_DEVICES", raising=False) m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
dtype = torch.bfloat16 dtype = torch.bfloat16
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
...@@ -51,22 +57,26 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int): ...@@ -51,22 +57,26 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int):
get_tp_group().device_communicator) get_tp_group().device_communicator)
symm_mem_comm = cuda_communicator.symm_mem_comm symm_mem_comm = cuda_communicator.symm_mem_comm
if symm_mem_comm is None or symm_mem_comm.disabled: if symm_mem_comm is None or symm_mem_comm.disabled:
pytest.skip("SymmMemCommunicator is not available or disabled.") # can't use skip under multiprocessing
q.put("SymmMemCommunicator is not available or disabled.")
return
inp_direct_symm_mem = torch.randint(1, inp_direct_symm_mem = torch.randint(1,
23, (test_size_elements, ), 23, (test_size_elements, ),
dtype=dtype, dtype=dtype,
device=device) device=device)
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
pytest.skip( # can't use skip under multiprocessing
q.put(
"SymmMemCommunicator isn't used for this world and input size." "SymmMemCommunicator isn't used for this world and input size."
) )
return
original_inp_direct_symm_mem = inp_direct_symm_mem.clone() original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem) out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
assert out_direct_symm_mem is not None assert out_direct_symm_mem is not None
group = get_tensor_model_parallel_group().device_group group = get_tp_group().device_group
dist.all_reduce(original_inp_direct_symm_mem, group=group) dist.all_reduce(original_inp_direct_symm_mem, group=group)
torch.testing.assert_close(out_direct_symm_mem, torch.testing.assert_close(out_direct_symm_mem,
original_inp_direct_symm_mem, original_inp_direct_symm_mem,
...@@ -100,9 +110,34 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, ...@@ -100,9 +110,34 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
world_size = tp_size * pipeline_parallel_size world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count(): if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.") pytest.skip("Not enough GPUs to run the test.")
q = mp.get_context('spawn').Queue()
mp.spawn(symm_mem_allreduce_worker,
args=(world_size, q),
nprocs=world_size)
try:
val = q.get(timeout=1)
except queue.Empty:
val = None
finally:
cleanup_dist_env_and_memory()
if val is not None:
pytest.skip(val)
# Enable SymmMemCommunicator
monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size) @pytest.mark.skipif(
cleanup_dist_env_and_memory() not current_platform.is_cuda(),
reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
world_size = 4
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
# Verify that the DataParallel runs without error
engine_args = EngineArgs(model="distilbert/distilgpt2",
enforce_eager=True,
enable_prefix_caching=True,
data_parallel_size=2,
tensor_parallel_size=2,
data_parallel_backend="mp")
LLMEngine.from_engine_args(engine_args)
...@@ -30,18 +30,21 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -30,18 +30,21 @@ class CudaCommunicator(DeviceCommunicatorBase):
unique_name: str = ""): unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name) super().__init__(cpu_group, device, device_group, unique_name)
if "tp" not in unique_name: if "tp" not in unique_name:
# only tp uses custom allreduce # custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False use_custom_allreduce = False
use_torch_symm_mem = False
else: else:
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
_ENABLE_CUSTOM_ALL_REDUCE) _ENABLE_CUSTOM_ALL_REDUCE)
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
# ep does not use pynccl # ep does not use pynccl
use_pynccl = "ep" not in unique_name use_pynccl = "ep" not in unique_name
self.use_pynccl = use_pynccl self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
# lazy import to avoid documentation build error # lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import ( from vllm.distributed.device_communicators.custom_all_reduce import (
...@@ -65,7 +68,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -65,7 +68,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm: Optional[CustomAllreduce] = None self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): if use_torch_symm_mem and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator( self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group, group=self.cpu_group,
device=self.device, device=self.device,
......
...@@ -182,7 +182,7 @@ if TYPE_CHECKING: ...@@ -182,7 +182,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
...@@ -1370,7 +1370,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1370,7 +1370,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Whether to use pytorch symmetric memory for allreduce # Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM": "VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))), lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))),
# Allows vllm to find tuned config under customized folder # Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER": "VLLM_TUNED_CONFIG_FOLDER":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment