Unverified Commit 5ec5eaf7 authored by Yi Zhang, committed by GitHub

fix allreduce test (#4909)

parent 0d7fe866
import ctypes
import logging
import multiprocessing as mp
import random
import socket
import time
import unittest
from typing import Any, List, Optional
import ray
import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm import _custom_ops as vllm_ops
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
logger = logging.getLogger(__name__)
def get_open_port() -> int:
    # try ipv4
@@ -33,22 +28,21 @@ def get_open_port() -> int:
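# Illustrative, hypothetical sketch (not from this commit): the body of
# get_open_port() is collapsed in this diff. A common pattern is to bind to
# port 0 so the OS assigns a free port, falling back from IPv4 to IPv6.
# The helper name below is made up for illustration.
def _get_open_port_example() -> int:
    for family, addr in ((socket.AF_INET, "127.0.0.1"), (socket.AF_INET6, "::1")):
        try:
            with socket.socket(family, socket.SOCK_STREAM) as s:
                s.bind((addr, 0))
                return s.getsockname()[1]
        except OSError:
            continue
    raise RuntimeError("could not find an open port")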
def multi_process_parallel(
    world_size: int,
    cls: Any,
    test_target: Any,
) -> None:
    # Using ray makes it easier to debug a failing worker
    # than plain multiprocessing does.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers.
    ray.init(log_to_driver=True)
    procs = []
    distributed_init_port = get_open_port()
    refs = []
    for rank in range(world_size):
        refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
    ray.get(refs)
    for i in range(world_size):
        proc = mp.Process(
            target=test_target,
            args=(world_size, i, distributed_init_port),
        )
        proc.start()
        procs.append(proc)
    ray.shutdown()
    for i in range(world_size):
        procs[i].join()
        assert procs[i].exitcode == 0
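# Illustrative, hypothetical variant (not from this commit): the same launcher
# with an explicit "spawn" start method, which avoids forking a process that may
# already hold CUDA state. "spawn" requires test_target to be picklable; the
# function name below is made up for illustration.
def _multi_process_parallel_spawn_example(world_size: int, test_target: Any) -> None:
    ctx = mp.get_context("spawn")
    distributed_init_port = get_open_port()
    procs = [
        ctx.Process(target=test_target, args=(world_size, i, distributed_init_port))
        for i in range(world_size)
    ]
    for proc in procs:
        proc.start()
    for proc in procs:
        proc.join()
        assert proc.exitcode == 0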
class TestCustomAllReduce(unittest.TestCase):
@@ -95,13 +89,8 @@ class TestCustomAllReduce(unittest.TestCase):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.correctness)

    def test_performance(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.performance)
            multi_process_parallel(world_size, self.correctness)
            print(f"custom allreduce tp = {world_size}: OK")
    def init_custom_allreduce(self, rank, world_size, group):
        buffer_max_size = 8 * 1024 * 1024
@@ -137,37 +126,6 @@ class TestCustomAllReduce(unittest.TestCase):
        self.free_shared_buffer(self.barrier_out_ptrs, group)
        custom_ops.custom_dispose(self.custom_ptr)

    def init_vllm_allreduce(self, rank, group):
        self.vllm_rank = rank
        self.vllm_max_size = 8 * 1024 * 1024
        self.vllm_meta_ptrs = self.create_shared_buffer(
            vllm_ops.meta_size() + self.vllm_max_size, group=group
        )
        self.vllm_buffer_ptrs = self.create_shared_buffer(
            self.vllm_max_size, group=group
        )
        self.vllm_rank_data = torch.empty(
            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
        )
        self.vllm_ptr = vllm_ops.init_custom_ar(
            self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
        )
        vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs)

    def vllm_allreduce(self, inp, out):
        vllm_ops.all_reduce(
            self.vllm_ptr,
            inp,
            out,
            self.vllm_buffer_ptrs[self.vllm_rank],
            self.vllm_max_size,
        )

    def free_vllm_allreduce(self, group):
        vllm_ops.dispose(self.vllm_ptr)
        self.free_shared_buffer(self.vllm_meta_ptrs, group)
        self.free_shared_buffer(self.vllm_buffer_ptrs, group)

    @staticmethod
    def init_distributed_env(world_size, rank, distributed_init_port):
        device = torch.device("cuda:0")
@@ -184,7 +142,6 @@ class TestCustomAllReduce(unittest.TestCase):
        return group
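    # Illustrative, hypothetical sketch (not from this commit): the collapsed body
    # of init_distributed_env() above typically sets the device and builds a NCCL
    # process group from the rendezvous port, roughly like this; the method name
    # and exact arguments are assumptions.
    @staticmethod
    def _init_distributed_env_example(world_size, rank, distributed_init_port):
        torch.cuda.set_device(0)  # each worker is pinned to a single visible GPU
        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://127.0.0.1:{distributed_init_port}",
            rank=rank,
            world_size=world_size,
        )
        return dist.group.WORLD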
    # compare result with torch.distributed
    @ray.remote(num_gpus=1, max_calls=1)
    def correctness(self, world_size, rank, distributed_init_port):
        group = self.init_distributed_env(world_size, rank, distributed_init_port)
@@ -205,40 +162,6 @@ class TestCustomAllReduce(unittest.TestCase):
        self.free_custom_allreduce(group)
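    # Illustrative, hypothetical helper (not from this commit): the collapsed
    # portion of correctness() above validates the custom kernel against a
    # torch.distributed reference, roughly in this form.
    def _reference_check_example(self, inp, out, group):
        ref = inp.clone()
        dist.all_reduce(ref, group=group)  # NCCL reference all-reduce
        torch.testing.assert_close(out, ref)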
    # compare performance with vllm
    @ray.remote(num_gpus=1, max_calls=1)
    def performance(self, world_size, rank, distributed_init_port):
        group = self.init_distributed_env(world_size, rank, distributed_init_port)
        self.init_vllm_allreduce(rank, group)
        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
        for sz in self.test_sizes:
            inp1 = torch.randint(
                1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
            )
            out1 = torch.empty_like(inp1)
            test_loop = 5000
            start = time.time()
            for _ in range(test_loop):
                self.custom_allreduce(inp1, out1)
            elapse_custom = time.time() - start
            start = time.time()
            for _ in range(test_loop):
                self.vllm_allreduce(inp1, out1)
            elapse_vllm = time.time() - start
            if rank == 0:
                logger.warning(
                    f"test_size = {sz}, world_size = {world_size}, "
                    f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, "
                    f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms "
                )
        self.free_custom_allreduce(group)
        self.free_vllm_allreduce(group)
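    # Illustrative, hypothetical sketch (not from this commit): wall-clock timing
    # of asynchronous CUDA collectives with time.time() is more reliable when the
    # device is synchronized around the measured region, e.g.:
    def _timed_custom_allreduce_example(self, inp, out, iters=100):
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            self.custom_allreduce(inp, out)
        torch.cuda.synchronize()
        return (time.time() - start) / iters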
if __name__ == "__main__":
    unittest.main()