Unverified Commit e6de9784 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[core][distributed] add stateless process group (#10216)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 36fc439d
import pytest import pytest
import ray import ray
import torch import torch
import torch.distributed as dist
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.utils import stateless_init_process_group from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import (cuda_device_count_stateless, from vllm.utils import (cuda_device_count_stateless,
update_environment_variables) update_environment_variables)
...@@ -41,42 +41,45 @@ def test_cuda_device_count_stateless(): ...@@ -41,42 +41,45 @@ def test_cuda_device_count_stateless():
def cpu_worker(rank, WORLD_SIZE): def cpu_worker(rank, WORLD_SIZE):
pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29500", pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29500",
rank=rank, rank=rank,
world_size=WORLD_SIZE, world_size=WORLD_SIZE)
backend="gloo")
if rank <= 2: if rank <= 2:
pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29501", pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29501",
rank=rank, rank=rank,
world_size=3, world_size=3)
backend="gloo")
data = torch.tensor([rank]) data = torch.tensor([rank])
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1) data = pg1.broadcast_obj(data, src=2)
assert data.item() == 2
if rank <= 2: if rank <= 2:
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2) data = torch.tensor([rank + 1])
item = data[0].item() data = pg2.broadcast_obj(data, src=2)
print(f"rank: {rank}, item: {item}") assert data.item() == 3
if rank == 3: pg2.barrier()
assert item == 6 pg1.barrier()
else:
assert item == 18
def gpu_worker(rank, WORLD_SIZE): def gpu_worker(rank, WORLD_SIZE):
pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29502", torch.cuda.set_device(rank)
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29502",
rank=rank, rank=rank,
world_size=WORLD_SIZE, world_size=WORLD_SIZE)
backend="nccl") pynccl1 = PyNcclCommunicator(pg1, device=rank)
pynccl1.disabled = False
if rank <= 2: if rank <= 2:
pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29503", pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29503",
rank=rank, rank=rank,
world_size=3, world_size=3)
backend="nccl") pynccl2 = PyNcclCommunicator(pg2, device=rank)
torch.cuda.set_device(rank) pynccl2.disabled = False
data = torch.tensor([rank]).cuda() data = torch.tensor([rank]).cuda()
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1) pynccl1.all_reduce(data)
pg1.barrier()
torch.cuda.synchronize()
if rank <= 2: if rank <= 2:
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2) pynccl2.all_reduce(data)
pg2.barrier()
torch.cuda.synchronize()
item = data[0].item() item = data[0].item()
print(f"rank: {rank}, item: {item}") print(f"rank: {rank}, item: {item}")
if rank == 3: if rank == 3:
...@@ -85,9 +88,31 @@ def gpu_worker(rank, WORLD_SIZE): ...@@ -85,9 +88,31 @@ def gpu_worker(rank, WORLD_SIZE):
assert item == 18 assert item == 18
def broadcast_worker(rank, WORLD_SIZE):
    """Worker body that checks `StatelessProcessGroup.broadcast_obj`.

    Rank 2 publishes the string "secret"; every other rank reads the
    broadcast from rank 2 and asserts it received that same string. All
    ranks then meet at a barrier before returning.

    Args:
        rank: rank of this worker inside the group.
        WORLD_SIZE: total number of workers joining the group.
    """
    group = StatelessProcessGroup.create(
        init_method="tcp://127.0.0.1:29504",
        rank=rank,
        world_size=WORLD_SIZE,
    )

    if rank != 2:
        # Receivers pass None as the payload; only the value returned
        # by the broadcast matters on this side.
        received = group.broadcast_obj(None, src=2)
        assert received == "secret"
    else:
        group.broadcast_obj("secret", src=2)

    # Every rank, sender included, must reach the barrier.
    group.barrier()
def allgather_worker(rank, WORLD_SIZE):
    """Worker body that checks `StatelessProcessGroup.all_gather_obj`.

    Each worker contributes its own rank and asserts that the gathered
    list equals ``[0, 1, ..., WORLD_SIZE - 1]``, then waits at a barrier.

    Args:
        rank: rank of this worker inside the group.
        WORLD_SIZE: total number of workers joining the group.
    """
    group = StatelessProcessGroup.create(
        init_method="tcp://127.0.0.1:29505",
        rank=rank,
        world_size=WORLD_SIZE,
    )

    gathered = group.all_gather_obj(rank)
    expected = list(range(WORLD_SIZE))
    assert gathered == expected

    group.barrier()
@multi_gpu_test(num_gpus=4) @multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker]) @pytest.mark.parametrize(
def test_stateless_init_process_group(worker): "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
def test_stateless_process_group(worker):
WORLD_SIZE = 4 WORLD_SIZE = 4
from multiprocessing import get_context from multiprocessing import get_context
ctx = get_context("fork") ctx = get_context("fork")
......
...@@ -9,6 +9,7 @@ from torch.distributed import ProcessGroup, ReduceOp ...@@ -9,6 +9,7 @@ from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.device_communicators.pynccl_wrapper import ( from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
ncclRedOpTypeEnum, ncclUniqueId) ncclRedOpTypeEnum, ncclUniqueId)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -18,7 +19,7 @@ class PyNcclCommunicator: ...@@ -18,7 +19,7 @@ class PyNcclCommunicator:
def __init__( def __init__(
self, self,
group: ProcessGroup, group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device], device: Union[int, str, torch.device],
library_path: Optional[str] = None, library_path: Optional[str] = None,
): ):
...@@ -33,13 +34,18 @@ class PyNcclCommunicator: ...@@ -33,13 +34,18 @@ class PyNcclCommunicator:
It is the caller's responsibility to make sure each communicator It is the caller's responsibility to make sure each communicator
is bind to a unique device. is bind to a unique device.
""" """
assert dist.is_initialized() if not isinstance(group, StatelessProcessGroup):
assert dist.get_backend(group) != dist.Backend.NCCL, ( assert dist.is_initialized()
"PyNcclCommunicator should be attached to a non-NCCL group.") assert dist.get_backend(group) != dist.Backend.NCCL, (
"PyNcclCommunicator should be attached to a non-NCCL group.")
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
else:
self.rank = group.rank
self.world_size = group.world_size
self.group = group self.group = group
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
# if world_size == 1, no need to create communicator # if world_size == 1, no need to create communicator
if self.world_size == 1: if self.world_size == 1:
...@@ -68,13 +74,17 @@ class PyNcclCommunicator: ...@@ -68,13 +74,17 @@ class PyNcclCommunicator:
else: else:
# construct an empty unique id # construct an empty unique id
self.unique_id = ncclUniqueId() self.unique_id = ncclUniqueId()
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group) if not isinstance(group, StatelessProcessGroup):
# arg `src` in `broadcast` is the global rank tensor = torch.ByteTensor(list(self.unique_id.internal))
dist.broadcast(tensor, src=ranks[0], group=group) ranks = dist.get_process_group_ranks(group)
byte_list = tensor.tolist() # arg `src` in `broadcast` is the global rank
for i, byte in enumerate(byte_list): dist.broadcast(tensor, src=ranks[0], group=group)
self.unique_id.internal[i] = byte byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
else:
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
if isinstance(device, int): if isinstance(device, int):
device = torch.device(f"cuda:{device}") device = torch.device(f"cuda:{device}")
elif isinstance(device, str): elif isinstance(device, str):
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# Adapted from # Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Sequence, Tuple import dataclasses
import pickle
import time
from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
import torch import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
_get_default_timeout,
is_nccl_available)
from torch.distributed.rendezvous import rendezvous from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs import vllm.envs as envs
...@@ -91,69 +91,139 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int, ...@@ -91,69 +91,139 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
return (start_layer, end_layer) return (start_layer, end_layer)
def stateless_init_process_group(init_method: str, rank: int, world_size: int, @dataclasses.dataclass
backend: str) -> ProcessGroup: class StatelessProcessGroup:
"""A replacement for `torch.distributed.init_process_group` that does not """A dataclass to hold a metadata store, and the rank, world_size of the
pollute the global state. group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects.
If we have process A and process B called `torch.distributed.init_process_group` """
to form a group, and then we want to form another group with process A, B, C, prefix: str
D, it is not possible in PyTorch, because process A and process B have already rank: int
formed a group, and process C and process D cannot join that group. This world_size: int
function is a workaround for this issue. store: torch._C._distributed_c10d.Store
data_expiration_seconds: int = 3600 # 1 hour
`torch.distributed.init_process_group` is a global call, while this function
is a stateless call. It will return a `ProcessGroup` object that can be used # dst rank -> counter
for collective communication. With this function, process A and process B send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
can call `stateless_init_process_group` to form a group, and then process A, B, # src rank -> counter
C, and D can call `stateless_init_process_group` to form another group. recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
""" # noqa broadcast_send_counter: int = 0
broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
backend = Backend(backend) # it is basically string default_factory=dict)
timeout = _get_default_timeout(backend)
# A deque to store the data entries, with key and timestamp.
store, rank, world_size = next( entries: Deque[Tuple[str,
rendezvous(init_method, rank, world_size, timeout=timeout)) float]] = dataclasses.field(default_factory=deque)
store.set_timeout(timeout)
def __post_init__(self):
group_rank = rank assert self.rank < self.world_size
group_size = world_size self.send_dst_counter = {i: 0 for i in range(self.world_size)}
self.recv_src_counter = {i: 0 for i in range(self.world_size)}
# Use a PrefixStore to avoid accidental overrides of keys used by self.broadcast_recv_src_counter = {
# different systems (e.g. RPC) in case the store is multi-tenant. i: 0
prefix_store = PrefixStore(init_method, store) for i in range(self.world_size)
}
pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
def send_obj(self, obj: Any, dst: int):
pg: ProcessGroup = ProcessGroup( """Send an object to a destination rank."""
prefix_store, self.expire_data()
group_rank, key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
group_size, self.store.set(key, pickle.dumps(obj))
pg_options, self.send_dst_counter[dst] += 1
) self.entries.append((key, time.time()))
if backend == "gloo": def expire_data(self):
from torch.distributed.distributed_c10d import ProcessGroupGloo """Expire data that is older than `data_expiration_seconds` seconds."""
backend_class = ProcessGroupGloo(prefix_store, while self.entries:
group_rank, # check the oldest entry
group_size, key, timestamp = self.entries[0]
timeout=timeout) if time.time() - timestamp > self.data_expiration_seconds:
backend_type = ProcessGroup.BackendType.GLOO self.store.delete_key(key)
device = torch.device("cpu") self.entries.popleft()
elif backend == "nccl": else:
assert is_nccl_available() break
from torch.distributed.distributed_c10d import ProcessGroupNCCL
def recv_obj(self, src: int) -> Any:
backend_options = ProcessGroupNCCL.Options() """Receive an object from a source rank."""
backend_options._timeout = timeout obj = pickle.loads(
self.store.get(
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
backend_options) ))
backend_type = ProcessGroup.BackendType.NCCL self.recv_src_counter[src] += 1
device = torch.device("cuda") return obj
backend_class._set_sequence_number_for_group() def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
"""Broadcast an object from a source rank to all other ranks.
pg._register_backend(device, backend_type, backend_class) It does not clean up after all ranks have received the object.
Use it for limited times, e.g., for initialization.
return pg """
if self.rank == src:
self.expire_data()
key = (f"{self.prefix}/broadcast_from/{src}/"
f"{self.broadcast_send_counter}")
self.store.set(key, pickle.dumps(obj))
self.broadcast_send_counter += 1
self.entries.append((key, time.time()))
return obj
else:
key = (f"{self.prefix}/broadcast_from/{src}/"
f"{self.broadcast_recv_src_counter[src]}")
recv_obj = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1
return recv_obj
def all_gather_obj(self, obj: Any) -> list[Any]:
    """Gather one object from every rank into a list ordered by rank.

    Walks the ranks in ascending order and performs one `broadcast_obj`
    per rank: on this rank's own turn it publishes `obj`, on every other
    turn it receives that rank's object.

    Args:
        obj: the object this rank contributes.

    Returns:
        A list of length `world_size` whose i-th element is the object
        contributed by rank i.
    """
    results = []
    for src in range(self.world_size):
        # On our own turn we are the sender, and broadcast_obj hands the
        # sender back the object it was given; otherwise we pass None
        # and take whatever rank `src` published.
        payload = obj if src == self.rank else None
        results.append(self.broadcast_obj(payload, src=src))
    return results
def barrier(self):
    """A barrier to synchronize all ranks.

    Runs one `broadcast_obj(None, src=i)` round for every rank `i` in
    ascending order. On its own turn this rank publishes a `None`
    marker; on every other turn it reads the marker published by that
    rank. The read presumably blocks until the key exists, which is
    what makes this a barrier -- confirm against the store's `get`
    semantics.

    Returns:
        None.
    """
    # broadcast_obj already tells sender from receiver by comparing
    # self.rank with src, so both roles go through the same call and no
    # per-rank branch is needed here.
    for src in range(self.world_size):
        self.broadcast_obj(None, src=src)
@staticmethod
def create(
    init_method: str,
    rank: int,
    world_size: int,
    data_expiration_seconds: int = 3600,
) -> "StatelessProcessGroup":
    """Build a group without touching torch.distributed's global state.

    This stands in for `torch.distributed.init_process_group`. If
    processes A and B have already called `init_process_group` to form
    a group, PyTorch offers no way to form a second group of A, B, C
    and D, because C and D cannot join the existing one. This function
    works around that: `init_process_group` is a global call, whereas
    this is a stateless call, so A and B can call
    `StatelessProcessGroup.create` to form one group, and A, B, C and D
    can call it again to form another.

    The returned object is meant for exchanging metadata only.

    Args:
        init_method: rendezvous URL; also used as the key prefix of the
            returned group.
        rank: rank of the calling process in the new group.
        world_size: number of processes in the new group.
        data_expiration_seconds: passed through unchanged to the
            returned group.

    Returns:
        A `StatelessProcessGroup` wrapping the store produced by the
        rendezvous.
    """
    from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT

    default_timeout = _DEFAULT_PG_TIMEOUT
    handshake = rendezvous(init_method,
                           rank,
                           world_size,
                           timeout=default_timeout)
    # Use the rank and world size reported by the rendezvous rather
    # than the raw arguments.
    store, rank, world_size = next(handshake)
    store.set_timeout(default_timeout)

    return StatelessProcessGroup(
        prefix=init_method,
        rank=rank,
        world_size=world_size,
        store=store,
        data_expiration_seconds=data_expiration_seconds,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment