Unverified Commit 5ec5eaf7 authored by Yi Zhang, committed by GitHub

fix allreduce test (#4909)

parent 0d7fe866
 import ctypes
-import logging
+import multiprocessing as mp
 import random
 import socket
-import time
 import unittest
 from typing import Any, List, Optional

-import ray
 import sgl_kernel.allreduce as custom_ops
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
-from vllm import _custom_ops as vllm_ops

 from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

-logger = logging.getLogger(__name__)

 def get_open_port() -> int:
     # try ipv4
@@ -33,22 +28,21 @@ def get_open_port() -> int:
 def multi_process_parallel(
     world_size: int,
-    cls: Any,
     test_target: Any,
 ) -> None:
-    # Using ray helps debugging the error when it failed
-    # as compared to multiprocessing.
-    # NOTE: We need to set working_dir for distributed tests,
-    # otherwise we may get import errors on ray workers
-    ray.init(log_to_driver=True)
-
+    procs = []
     distributed_init_port = get_open_port()
-    refs = []
-    for rank in range(world_size):
-        refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
-    ray.get(refs)
-
-    ray.shutdown()
+    for i in range(world_size):
+        proc = mp.Process(
+            target=test_target,
+            args=(world_size, i, distributed_init_port),
+        )
+        proc.start()
+        procs.append(proc)
+
+    for i in range(world_size):
+        procs[i].join()
+        assert procs[i].exitcode == 0


 class TestCustomAllReduce(unittest.TestCase):
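Note: the rewritten multi_process_parallel drops ray in favor of plain multiprocessing: one mp.Process per rank, all pointed at the same rendezvous port, joined at the end with an exit-code assertion. Below is a minimal standalone sketch of the same spawn/join pattern, using a CPU-only gloo group instead of the GPU setup in this test; _get_open_port, _worker, and run_parallel are illustrative names, not part of the change.

import multiprocessing as mp
import socket

import torch
import torch.distributed as dist


def _get_open_port() -> int:
    # Bind to port 0 so the OS picks a free port, then reuse it for the rendezvous.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def _worker(world_size: int, rank: int, port: int) -> None:
    # Each child joins the same process group and runs a trivial all-reduce check.
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://127.0.0.1:{port}",
        rank=rank,
        world_size=world_size,
    )
    t = torch.ones(4)
    dist.all_reduce(t)
    assert torch.equal(t, torch.full((4,), float(world_size)))
    dist.destroy_process_group()


def run_parallel(world_size: int) -> None:
    # Same spawn / join / exitcode pattern as the new multi_process_parallel above.
    port = _get_open_port()
    procs = [
        mp.Process(target=_worker, args=(world_size, rank, port))
        for rank in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
        assert p.exitcode == 0


if __name__ == "__main__":
    run_parallel(2)

With this pattern a failing rank surfaces as a non-zero exitcode on join, and the test no longer needs a ray cluster or the working_dir setup mentioned in the deleted comments.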
@@ -95,13 +89,8 @@ class TestCustomAllReduce(unittest.TestCase):
         for world_size in self.world_sizes:
             if world_size > torch.cuda.device_count():
                 continue
-            multi_process_parallel(world_size, self, self.correctness)
-
-    def test_performance(self):
-        for world_size in self.world_sizes:
-            if world_size > torch.cuda.device_count():
-                continue
-            multi_process_parallel(world_size, self, self.performance)
+            multi_process_parallel(world_size, self.correctness)
+            print(f"custom allreduce tp = {world_size}: OK")

     def init_custom_allreduce(self, rank, world_size, group):
         buffer_max_size = 8 * 1024 * 1024
@@ -137,37 +126,6 @@ class TestCustomAllReduce(unittest.TestCase):
         self.free_shared_buffer(self.barrier_out_ptrs, group)
         custom_ops.custom_dispose(self.custom_ptr)

-    def init_vllm_allreduce(self, rank, group):
-        self.vllm_rank = rank
-        self.vllm_max_size = 8 * 1024 * 1024
-        self.vllm_meta_ptrs = self.create_shared_buffer(
-            vllm_ops.meta_size() + self.vllm_max_size, group=group
-        )
-        self.vllm_buffer_ptrs = self.create_shared_buffer(
-            self.vllm_max_size, group=group
-        )
-        self.vllm_rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
-        )
-        self.vllm_ptr = vllm_ops.init_custom_ar(
-            self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
-        )
-        vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs)
-
-    def vllm_allreduce(self, inp, out):
-        vllm_ops.all_reduce(
-            self.vllm_ptr,
-            inp,
-            out,
-            self.vllm_buffer_ptrs[self.vllm_rank],
-            self.vllm_max_size,
-        )
-
-    def free_vllm_allreduce(self, group):
-        vllm_ops.dispose(self.vllm_ptr)
-        self.free_shared_buffer(self.vllm_meta_ptrs, group)
-        self.free_shared_buffer(self.vllm_buffer_ptrs, group)
-
     @staticmethod
     def init_distributed_env(world_size, rank, distributed_init_port):
         device = torch.device("cuda:0")
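Note: create_shared_buffer and free_shared_buffer, used above for the barrier and buffer pointers, are defined outside this hunk. They follow a common CUDA IPC pattern: each rank allocates device memory, exports an IPC handle, all-gathers the handles over the group, and opens every peer's handle so all ranks hold a pointer into every rank's buffer. The sketch below illustrates that pattern only; it assumes CudaRTLibrary exposes cudaMalloc, cudaIpcGetMemHandle, cudaIpcOpenMemHandle, and cudaFree wrappers (as the vLLM-derived cuda_wrapper does) and is not the test's exact helper.

import ctypes
from typing import List, Optional

import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary


def create_shared_buffer(
    size_in_bytes: int, group: Optional[ProcessGroup] = None
) -> List[int]:
    # Allocate a local device buffer and export a CUDA IPC handle for it.
    # NOTE: cudaMalloc / cudaIpcGetMemHandle / cudaIpcOpenMemHandle are assumed
    # wrapper methods on CudaRTLibrary.
    lib = CudaRTLibrary()
    pointer = lib.cudaMalloc(size_in_bytes)
    handle = lib.cudaIpcGetMemHandle(pointer)
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)

    # Exchange handles so every rank can open every other rank's buffer.
    handles = [None] * world_size
    dist.all_gather_object(handles, handle, group=group)

    pointers: List[int] = []
    for i, h in enumerate(handles):
        if i == rank:
            pointers.append(pointer.value)  # this rank's own allocation
        else:
            pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # peer's buffer
    return pointers


def free_shared_buffer(
    pointers: List[int], group: Optional[ProcessGroup] = None
) -> None:
    # Free only the allocation owned by this rank; peers release their own.
    rank = dist.get_rank(group=group)
    lib = CudaRTLibrary()
    lib.cudaFree(ctypes.c_void_p(pointers[rank]))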
@@ -184,7 +142,6 @@ class TestCustomAllReduce(unittest.TestCase):
         return group

     # compare result with torch.distributed
-    @ray.remote(num_gpus=1, max_calls=1)
     def correctness(self, world_size, rank, distributed_init_port):
         group = self.init_distributed_env(world_size, rank, distributed_init_port)
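Note: the body of correctness is elided between these hunks; per the comment above, it checks the custom kernel against torch.distributed. A minimal sketch of that kind of per-rank check, run inside a spawned worker after the distributed env and custom all-reduce state are initialized; custom_allreduce and the sizes here are illustrative, not the test's exact values.

import torch
import torch.distributed as dist


def check_against_torch_distributed(
    custom_allreduce, sizes=(512, 4096, 32 * 1024)
) -> None:
    # For each size, run the custom all-reduce and torch.distributed's all_reduce
    # on identical inputs, then require the results to match.
    for sz in sizes:
        inp = torch.randint(
            1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
        )
        out = torch.empty_like(inp)
        custom_allreduce(inp, out)  # assumed callable, e.g. self.custom_allreduce

        ref = inp.clone()
        dist.all_reduce(ref)  # reference result from torch.distributed

        torch.testing.assert_close(out, ref)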
@@ -205,40 +162,6 @@ class TestCustomAllReduce(unittest.TestCase):
         self.free_custom_allreduce(group)

-    # compare performance with vllm
-    @ray.remote(num_gpus=1, max_calls=1)
-    def performance(self, world_size, rank, distributed_init_port):
-        group = self.init_distributed_env(world_size, rank, distributed_init_port)
-
-        self.init_vllm_allreduce(rank, group)
-        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
-        for sz in self.test_sizes:
-            inp1 = torch.randint(
-                1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
-            )
-            out1 = torch.empty_like(inp1)
-            test_loop = 5000
-            start = time.time()
-            for _ in range(test_loop):
-                self.custom_allreduce(inp1, out1)
-            elapse_custom = time.time() - start
-
-            start = time.time()
-            for _ in range(test_loop):
-                self.vllm_allreduce(inp1, out1)
-            elapse_vllm = time.time() - start
-
-            if rank == 0:
-                logger.warning(
-                    f"test_size = {sz}, world_size = {world_size}, "
-                    f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, "
-                    f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms "
-                )
-
-        self.free_custom_allreduce(group)
-        self.free_vllm_allreduce(group)
-

 if __name__ == "__main__":
     unittest.main()