"vscode:/vscode.git/clone" did not exist on "fa0050db08660535368ec5ea41d313bdeb69909d"
Unverified Commit 0d4ea3fb authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[core][distributed] use tcp store directly (#10275)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent 112fa0bb
...@@ -43,12 +43,15 @@ def test_cuda_device_count_stateless(): ...@@ -43,12 +43,15 @@ def test_cuda_device_count_stateless():
def cpu_worker(rank, WORLD_SIZE, port1, port2): def cpu_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}", pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank, rank=rank,
world_size=WORLD_SIZE) world_size=WORLD_SIZE)
if rank <= 2: if rank <= 2:
pg2 = StatelessProcessGroup.create( pg2 = StatelessProcessGroup.create(host="127.0.0.1",
init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3) port=port2,
rank=rank,
world_size=3)
data = torch.tensor([rank]) data = torch.tensor([rank])
data = pg1.broadcast_obj(data, src=2) data = pg1.broadcast_obj(data, src=2)
assert data.item() == 2 assert data.item() == 2
...@@ -62,14 +65,17 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2): ...@@ -62,14 +65,17 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
def gpu_worker(rank, WORLD_SIZE, port1, port2): def gpu_worker(rank, WORLD_SIZE, port1, port2):
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}", pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank, rank=rank,
world_size=WORLD_SIZE) world_size=WORLD_SIZE)
pynccl1 = PyNcclCommunicator(pg1, device=rank) pynccl1 = PyNcclCommunicator(pg1, device=rank)
pynccl1.disabled = False pynccl1.disabled = False
if rank <= 2: if rank <= 2:
pg2 = StatelessProcessGroup.create( pg2 = StatelessProcessGroup.create(host="127.0.0.1",
init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3) port=port2,
rank=rank,
world_size=3)
pynccl2 = PyNcclCommunicator(pg2, device=rank) pynccl2 = PyNcclCommunicator(pg2, device=rank)
pynccl2.disabled = False pynccl2.disabled = False
data = torch.tensor([rank]).cuda() data = torch.tensor([rank]).cuda()
...@@ -89,7 +95,8 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2): ...@@ -89,7 +95,8 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
def broadcast_worker(rank, WORLD_SIZE, port1, port2): def broadcast_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}", pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank, rank=rank,
world_size=WORLD_SIZE) world_size=WORLD_SIZE)
if rank == 2: if rank == 2:
...@@ -101,7 +108,8 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2): ...@@ -101,7 +108,8 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
def allgather_worker(rank, WORLD_SIZE, port1, port2): def allgather_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}", pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank, rank=rank,
world_size=WORLD_SIZE) world_size=WORLD_SIZE)
data = pg1.all_gather_obj(rank) data = pg1.all_gather_obj(rank)
...@@ -109,8 +117,6 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2): ...@@ -109,8 +117,6 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2):
pg1.barrier() pg1.barrier()
# TODO: investigate why this test is flaky. It hangs during initialization.
@pytest.mark.skip("Skip the test because it is flaky.")
@multi_gpu_test(num_gpus=4) @multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]) "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
......
...@@ -9,7 +9,7 @@ from collections import deque ...@@ -9,7 +9,7 @@ from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple from typing import Any, Deque, Dict, Optional, Sequence, Tuple
import torch import torch
from torch.distributed.rendezvous import rendezvous from torch.distributed import TCPStore
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -97,7 +97,6 @@ class StatelessProcessGroup: ...@@ -97,7 +97,6 @@ class StatelessProcessGroup:
group. Only use it to communicate metadata between processes. group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects. For data-plane communication, create NCCL-related objects.
""" """
prefix: str
rank: int rank: int
world_size: int world_size: int
store: torch._C._distributed_c10d.Store store: torch._C._distributed_c10d.Store
...@@ -127,7 +126,7 @@ class StatelessProcessGroup: ...@@ -127,7 +126,7 @@ class StatelessProcessGroup:
def send_obj(self, obj: Any, dst: int): def send_obj(self, obj: Any, dst: int):
"""Send an object to a destination rank.""" """Send an object to a destination rank."""
self.expire_data() self.expire_data()
key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}" key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
self.store.set(key, pickle.dumps(obj)) self.store.set(key, pickle.dumps(obj))
self.send_dst_counter[dst] += 1 self.send_dst_counter[dst] += 1
self.entries.append((key, time.time())) self.entries.append((key, time.time()))
...@@ -147,8 +146,7 @@ class StatelessProcessGroup: ...@@ -147,8 +146,7 @@ class StatelessProcessGroup:
"""Receive an object from a source rank.""" """Receive an object from a source rank."""
obj = pickle.loads( obj = pickle.loads(
self.store.get( self.store.get(
f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}" f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
))
self.recv_src_counter[src] += 1 self.recv_src_counter[src] += 1
return obj return obj
...@@ -159,14 +157,14 @@ class StatelessProcessGroup: ...@@ -159,14 +157,14 @@ class StatelessProcessGroup:
""" """
if self.rank == src: if self.rank == src:
self.expire_data() self.expire_data()
key = (f"{self.prefix}/broadcast_from/{src}/" key = (f"broadcast_from/{src}/"
f"{self.broadcast_send_counter}") f"{self.broadcast_send_counter}")
self.store.set(key, pickle.dumps(obj)) self.store.set(key, pickle.dumps(obj))
self.broadcast_send_counter += 1 self.broadcast_send_counter += 1
self.entries.append((key, time.time())) self.entries.append((key, time.time()))
return obj return obj
else: else:
key = (f"{self.prefix}/broadcast_from/{src}/" key = (f"broadcast_from/{src}/"
f"{self.broadcast_recv_src_counter[src]}") f"{self.broadcast_recv_src_counter[src]}")
recv_obj = pickle.loads(self.store.get(key)) recv_obj = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1 self.broadcast_recv_src_counter[src] += 1
...@@ -194,7 +192,8 @@ class StatelessProcessGroup: ...@@ -194,7 +192,8 @@ class StatelessProcessGroup:
@staticmethod @staticmethod
def create( def create(
init_method: str, host: str,
port: int,
rank: int, rank: int,
world_size: int, world_size: int,
data_expiration_seconds: int = 3600, data_expiration_seconds: int = 3600,
...@@ -214,15 +213,14 @@ class StatelessProcessGroup: ...@@ -214,15 +213,14 @@ class StatelessProcessGroup:
can call `StatelessProcessGroup.create` to form a group, and then process A, B, can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group. C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa """ # noqa
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT store = TCPStore(
timeout = _DEFAULT_PG_TIMEOUT host_name=host,
port=port,
store, rank, world_size = next( world_size=world_size,
rendezvous(init_method, rank, world_size, timeout=timeout)) is_master=(rank == 0),
store.set_timeout(timeout) )
return StatelessProcessGroup( return StatelessProcessGroup(
prefix=init_method,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
store=store, store=store,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment