Unverified Commit 0d4ea3fb authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[core][distributed] use tcp store directly (#10275)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent 112fa0bb
......@@ -43,12 +43,15 @@ def test_cuda_device_count_stateless():
def cpu_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank,
world_size=WORLD_SIZE)
if rank <= 2:
pg2 = StatelessProcessGroup.create(
init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
port=port2,
rank=rank,
world_size=3)
data = torch.tensor([rank])
data = pg1.broadcast_obj(data, src=2)
assert data.item() == 2
......@@ -62,14 +65,17 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
def gpu_worker(rank, WORLD_SIZE, port1, port2):
torch.cuda.set_device(rank)
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank,
world_size=WORLD_SIZE)
pynccl1 = PyNcclCommunicator(pg1, device=rank)
pynccl1.disabled = False
if rank <= 2:
pg2 = StatelessProcessGroup.create(
init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
port=port2,
rank=rank,
world_size=3)
pynccl2 = PyNcclCommunicator(pg2, device=rank)
pynccl2.disabled = False
data = torch.tensor([rank]).cuda()
......@@ -89,7 +95,8 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
def broadcast_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank,
world_size=WORLD_SIZE)
if rank == 2:
......@@ -101,7 +108,8 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
def allgather_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank,
world_size=WORLD_SIZE)
data = pg1.all_gather_obj(rank)
......@@ -109,8 +117,6 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2):
pg1.barrier()
# TODO: investigate why this test is flaky. It hangs during initialization.
@pytest.mark.skip("Skip the test because it is flaky.")
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize(
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
......
......@@ -9,7 +9,7 @@ from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
import torch
from torch.distributed.rendezvous import rendezvous
from torch.distributed import TCPStore
import vllm.envs as envs
from vllm.logger import init_logger
......@@ -97,7 +97,6 @@ class StatelessProcessGroup:
group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects.
"""
prefix: str
rank: int
world_size: int
store: torch._C._distributed_c10d.Store
......@@ -127,7 +126,7 @@ class StatelessProcessGroup:
def send_obj(self, obj: Any, dst: int):
"""Send an object to a destination rank."""
self.expire_data()
key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
self.store.set(key, pickle.dumps(obj))
self.send_dst_counter[dst] += 1
self.entries.append((key, time.time()))
......@@ -147,8 +146,7 @@ class StatelessProcessGroup:
"""Receive an object from a source rank."""
obj = pickle.loads(
self.store.get(
f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
))
f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
self.recv_src_counter[src] += 1
return obj
......@@ -159,14 +157,14 @@ class StatelessProcessGroup:
"""
if self.rank == src:
self.expire_data()
key = (f"{self.prefix}/broadcast_from/{src}/"
key = (f"broadcast_from/{src}/"
f"{self.broadcast_send_counter}")
self.store.set(key, pickle.dumps(obj))
self.broadcast_send_counter += 1
self.entries.append((key, time.time()))
return obj
else:
key = (f"{self.prefix}/broadcast_from/{src}/"
key = (f"broadcast_from/{src}/"
f"{self.broadcast_recv_src_counter[src]}")
recv_obj = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1
......@@ -194,7 +192,8 @@ class StatelessProcessGroup:
@staticmethod
def create(
init_method: str,
host: str,
port: int,
rank: int,
world_size: int,
data_expiration_seconds: int = 3600,
......@@ -214,15 +213,14 @@ class StatelessProcessGroup:
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
timeout = _DEFAULT_PG_TIMEOUT
store, rank, world_size = next(
rendezvous(init_method, rank, world_size, timeout=timeout))
store.set_timeout(timeout)
store = TCPStore(
host_name=host,
port=port,
world_size=world_size,
is_master=(rank == 0),
)
return StatelessProcessGroup(
prefix=init_method,
rank=rank,
world_size=world_size,
store=store,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment