Unverified Commit 9fccda31 authored by Adarsh Shirawalmath, committed by GitHub

[Feature] use pytest for sgl-kernel (#4896)

parent 4ede6770
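For context on the changes below: each converted test module keeps a script entry point that delegates to pytest, and the CI step switches from per-file execution to a single pytest run over the sgl-kernel/tests/ directory. The per-file pattern repeated throughout this diff is simply:

import pytest

if __name__ == "__main__":
    pytest.main([__file__])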
......@@ -80,7 +80,8 @@ jobs:
- name: Install
run: |
pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1
bash scripts/ci_install_dependency.sh
pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.7.2
pip3 uninstall sgl-kernel -y || true
pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
pip3 list | grep sgl-kernel
......@@ -89,7 +90,7 @@ jobs:
timeout-minutes: 30
run: |
cd sgl-kernel
find tests -name "test_*.py" | xargs -n 1 python3
pytest tests/
- name: Uninstall dependencies
run: |
......
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import verify_tree_greedy
......@@ -85,14 +86,14 @@ def test_verify_tree_greedy():
print(f"{accept_index=}")
print(f"{accept_token_num=}")
return predicts, accept_index, accept_token_num
if __name__ == "__main__":
predicts, accept_index, accept_token_num = test_verify_tree_greedy()
# Check the expected output.
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 3, 4, 5],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [3, 2]
if __name__ == "__main__":
pytest.main([__file__])
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import tree_speculative_sampling_target_only
......@@ -97,26 +98,21 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
print(f"{accept_index=}")
print(f"{accept_token_num=}")
return predicts, accept_index, accept_token_num
if threshold_single == 1 and threshold_acc == 1:
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 3, 4, 5],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [3, 2]
elif threshold_single == 0 and threshold_acc == 0:
assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 1, 2, -1],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [2, 2]
if __name__ == "__main__":
predicts, accept_index, accept_token_num = (
test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1)
)
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 3, 4, 5],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [3, 2]
predicts, accept_index, accept_token_num = (
test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0)
)
assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 1, 2, -1],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [2, 2]
pytest.main([__file__])
......@@ -128,5 +128,4 @@ def test_awq_dequant_compare_implementations(
if __name__ == "__main__":
# Run the specific test function directly
pytest.main([__file__])
import unittest
import pytest
import torch
from sgl_kernel import cublas_grouped_gemm
def torch_grouped_gemm(a_array, b_array, out_dtype):
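# Reference path: compute each group's output with a plain torch.matmul of the
# (M, K) activation against the transposed (N, K) weight, cast to out_dtype.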
c_array = []
for a, b in zip(a_array, b_array):
c_array.append(torch.matmul(a, b.t()).to(out_dtype))
return c_array
class TestGroupedGemm(unittest.TestCase):
def _test_accuracy(self, Ms, Ns, Ks, out_dtype):
group_count = len(Ms)
a_array = []
b_array = []
c_array_cublas = []
for i in range(group_count):
M, N, K = Ms[i], Ns[i], Ks[i]
a_array.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5)
b_array.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5)
c_array_cublas.append(torch.empty((M, N), device="cuda", dtype=out_dtype))
c_array_torch = torch_grouped_gemm(a_array, b_array, out_dtype)
cublas_grouped_gemm(a_array, b_array, c_array_cublas, out_dtype)
for i in range(group_count):
M, N, K = Ms[i], Ns[i], Ks[i]
torch.testing.assert_close(c_array_torch[i], c_array_cublas[i])
print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
def test_accuracy(self):
Ms = [1, 16, 32, 256, 1024]
Ns = [2, 16, 128, 256, 4096]
Ks = [3, 16, 32, 512, 8192]
out_dtypes = [torch.float16, torch.bfloat16]
for out_dtype in out_dtypes:
self._test_accuracy(Ms, Ns, Ks, out_dtype)
return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)]
skip_condition = not torch.cuda.is_available() or (
torch.version.cuda is None
or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
)
@pytest.mark.skipif(
skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
)
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("M", [1, 16, 32, 256, 1024])
@pytest.mark.parametrize("N", [2, 16, 128, 256, 4096])
@pytest.mark.parametrize("K", [3, 16, 32, 512, 8192])
def test_grouped_gemm_accuracy(out_dtype, M, N, K):
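# Single-group sanity check: the grouped cuBLAS GEMM on one (M, K) x (K, N)
# pair must match both the torch reference helper and a direct matmul.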
a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5
b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5
expected = torch.matmul(a, b.t()).to(out_dtype)
a_array = [a]
b_array = [b]
c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)]
result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0]
cublas_grouped_gemm(a_array, b_array, c_array, out_dtype)
torch.testing.assert_close(result_torch, expected)
torch.testing.assert_close(c_array[0], expected)
if __name__ == "__main__":
if torch.cuda.is_available():
cuda_version = tuple(map(int, torch.version.cuda.split(".")))
if cuda_version >= (12, 5):
unittest.main()
else:
print(f"Cuda version {cuda_version} lower than 12.5, not executing tests.")
pytest.main([__file__])
import unittest
import os
import random
from typing import Optional, Type
import pytest
import torch
from sgl_kernel import fp8_blockwise_scaled_mm
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
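# e.g. cdiv(300, 128) == 3: a length-300 dimension needs three 128-wide scale groups.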
......@@ -23,7 +24,6 @@ def baseline_scaled_mm(
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# We treat N-dimensional group scaling as extended numpy-style broadcasting:
# numpy simply stretches dimensions with an extent of 1 to match the
# target shape by repeating the data along that dimension (broadcasting).
......@@ -51,62 +51,44 @@ def baseline_scaled_mm(
scale_a = group_broadcast(scale_a, a.shape)
scale_b = group_broadcast(scale_b, b.shape)
output = torch.mm(
(scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
).to(out_dtype)
if bias is not None:
output = output + bias
return output
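The scale_shape and group_broadcast helpers referenced above are elided from this hunk; as a hedged illustration only (not the helpers used in the test), stretching a per-group scale tensor back to the full matrix shape can be sketched as:

import torch

def expand_group_scale(scale: torch.Tensor, group_shape, target_shape) -> torch.Tensor:
    # Illustrative sketch: repeat each scale entry along every dimension by its
    # group extent (numpy-style broadcasting by explicit repetition), then trim
    # in case the matrix dimension is not an exact multiple of the group size.
    out = scale
    for dim, group in enumerate(group_shape):
        out = out.repeat_interleave(group, dim=dim)
    return out[tuple(slice(0, s) for s in target_shape)]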
class TestFp8Gemm(unittest.TestCase):
def _test_accuracy_once(self, M, N, K, out_dtype, device):
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
b_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
scale_a_group_shape = (1, 128)
scale_b_group_shape = (128, 128)
scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
scale_a = scale_a.t().contiguous().t()
scale_b = scale_b.t().contiguous().t()
o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
rtol = 0.02
atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
def test_accuracy(self):
Ms = [1, 128, 512, 1024, 4096]
Ns = [128, 512, 1024, 4096]
Ks = [512, 1024, 4096, 8192, 16384]
out_dtypes = [torch.bfloat16, torch.float16]
for M in Ms:
for N in Ns:
for K in Ks:
for out_dtype in out_dtypes:
self._test_accuracy_once(M, N, K, out_dtype, "cuda")
def _test_accuracy_once(M, N, K, out_dtype, device):
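# Build random fp8 e4m3 operands with per-(1, 128) activation scales and
# per-(128, 128) weight scales, then compare the blockwise kernel against the
# baseline implementation within loose fp8 tolerances.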
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
scale_a_group_shape = (1, 128)
scale_b_group_shape = (128, 128)
scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
scale_a = scale_a.t().contiguous().t()
scale_b = scale_b.t().contiguous().t()
o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
rtol = 0.02
atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_accuracy(M, N, K, out_dtype):
_test_accuracy_once(M, N, K, out_dtype, "cuda")
if __name__ == "__main__":
unittest.main()
pytest.main([__file__])
import unittest
import pytest
import torch
from sgl_kernel import fp8_scaled_mm
def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
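# Reference: accumulate the matmul in float32, apply per-row scale_a and
# per-column scale_b, cast to out_dtype, then add the optional bias.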
o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
o = o.to(torch.float32)
temp1 = o * scale_a.view(-1, 1)
temp2 = temp1 * scale_b.view(1, -1)
final = temp2.to(out_dtype)
if bias is not None:
final = final + bias.view(1, -1)
return final
class TestFp8Gemm(unittest.TestCase):
def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
b_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
if with_bias:
bias = torch.randn((N,), device=device, dtype=out_dtype)
else:
bias = None
o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
b_fp8 = b_fp8.t()
o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
rtol = 0.02
atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
def test_accuracy(self):
Ms = [1, 128, 512, 1024, 4096]
Ns = [16, 128, 512, 1024, 4096]
Ks = [512, 1024, 4096, 8192, 16384]
bias_opts = [True, False]
out_dtypes = [torch.bfloat16, torch.float16]
for M in Ms:
for N in Ns:
for K in Ks:
for with_bias in bias_opts:
for out_dtype in out_dtypes:
self._test_accuracy_once(
M, N, K, with_bias, out_dtype, "cuda"
)
def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
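# Random fp8 e4m3 inputs with per-row/per-column float32 scales and an
# optional bias; compare fp8_scaled_mm against the torch reference above.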
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
if with_bias:
bias = torch.randn((N,), device=device, dtype=out_dtype)
else:
bias = None
b_fp8 = b_fp8.t()
o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
rtol = 0.02
atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("with_bias", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_accuracy(M, N, K, with_bias, out_dtype):
_test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")
if __name__ == "__main__":
unittest.main()
pytest.main([__file__])
import unittest
import pytest
import torch
from sgl_kernel import int8_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
......@@ -18,39 +17,31 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
return o.to(out_dtype)
class TestInt8Gemm(unittest.TestCase):
def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
a = to_int8(torch.randn((M, K), device=device) * 5)
b = to_int8(torch.randn((N, K), device=device).t() * 5)
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
if with_bias:
bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
else:
bias = None
o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
torch.testing.assert_close(o, o1)
torch.testing.assert_close(o, o2)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
def test_accuracy(self):
Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]
Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
Ks = [512, 1024, 4096, 8192, 16384]
bias_opts = [True, False]
out_dtypes = [torch.float16, torch.bfloat16]
for M in Ms:
for N in Ns:
for K in Ks:
for with_bias in bias_opts:
for out_dtype in out_dtypes:
self._test_accuracy_once(
M, N, K, with_bias, out_dtype, "cuda"
)
def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
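# Check sgl-kernel's int8_scaled_mm against the torch reference above and
# vLLM's cutlass_scaled_mm on identically quantized inputs.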
a = to_int8(torch.randn((M, K), device=device) * 5)
b = to_int8(torch.randn((N, K), device=device).t() * 5)
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
if with_bias:
bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
else:
bias = None
o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
torch.testing.assert_close(o, o1)
torch.testing.assert_close(o, o2)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("with_bias", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
def test_accuracy(M, N, K, with_bias, out_dtype):
_test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")
if __name__ == "__main__":
unittest.main()
pytest.main([__file__])
......@@ -51,5 +51,4 @@ def test_per_token_quant_compare_implementations(
if __name__ == "__main__":
# Run the specific test function directly
pytest.main([__file__])
......@@ -13,154 +13,185 @@ from torch.distributed import ProcessGroup
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
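# Per-rank worker: pin the process to its GPU, join the NCCL group, allocate
# the shared IPC buffers, then compare the custom all-reduce against
# torch.distributed.all_reduce across several sizes and dtypes.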
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
ranks = list(range(world_size))
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
dist.init_process_group(
backend="nccl",
init_method=distributed_init_method,
rank=rank,
world_size=world_size,
)
group = dist.group.WORLD
buffer_max_size = 8 * 1024 * 1024
barrier_max_size = 8 * (24 + 2) * 8
buffer_ptrs = None
tmp_result_buffer_ptrs = None
barrier_in_ptrs = None
barrier_out_ptrs = None
custom_ptr = None
try:
buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
buffer_max_size, group=group
)
tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
buffer_max_size, group=group
)
barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer(
barrier_max_size, group=group
)
barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer(
barrier_max_size, group=group
)
rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
custom_ptr = custom_ops.init_custom_reduce(
rank,
world_size,
rank_data,
buffer_ptrs,
tmp_result_buffer_ptrs,
barrier_in_ptrs,
barrier_out_ptrs,
)
test_loop = 10
for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
for _ in range(test_loop):
inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
inp1_ref = inp1.clone()
out1 = torch.empty_like(inp1)
custom_ops.custom_reduce(custom_ptr, inp1, out1)
dist.all_reduce(inp1_ref, group=group)
torch.testing.assert_close(out1, inp1_ref)
finally:
dist.barrier(group=group)
if custom_ptr is not None:
custom_ops.custom_dispose(custom_ptr)
if buffer_ptrs:
TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
if tmp_result_buffer_ptrs:
TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group)
if barrier_in_ptrs:
TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group)
if barrier_out_ptrs:
TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group)
dist.destroy_process_group(group=group)
def get_open_port() -> int:
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
s.bind(("::1", 0))
return s.getsockname()[1]
def multi_process_parallel(
world_size: int,
test_target: Any,
world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
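# Spawn one worker process per rank on a shared rendezvous port, forward any
# extra target_args, and fail if any worker exits with a non-zero code.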
mp.set_start_method("spawn", force=True)
procs = []
distributed_init_port = get_open_port()
for i in range(world_size):
proc = mp.Process(
target=test_target,
args=(world_size, i, distributed_init_port),
)
proc_args = (world_size, i, distributed_init_port) + target_args
proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
proc.start()
procs.append(proc)
for i in range(world_size):
procs[i].join()
assert procs[i].exitcode == 0
assert (
procs[i].exitcode == 0
), f"Process {i} failed with exit code {procs[i].exitcode}"
class TestCustomAllReduce(unittest.TestCase):
@classmethod
def setUpClass(cls):
random.seed(42)
cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
cls.world_sizes = [2, 4, 8]
test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
world_sizes = [2, 4, 8]
@staticmethod
def create_shared_buffer(
size_in_bytes: int, group: Optional[ProcessGroup] = None
) -> List[int]:
"""
Creates a shared buffer and returns a list of pointers
representing the buffer on all processes in the group.
"""
lib = CudaRTLibrary()
pointer = lib.cudaMalloc(size_in_bytes)
handle = lib.cudaIpcGetMemHandle(pointer)
if group is None:
group = dist.group.WORLD
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
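# Exchange the cudaIpcMemHandle as raw bytes via all_gather on byte tensors,
# then rebuild a handle object for every rank from the gathered bytes.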
handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}")
gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)]
dist.all_gather(gathered_tensors, input_tensor, group=group)
handles = []
handle_type = type(handle)
for tensor in gathered_tensors:
bytes_list = tensor.cpu().tolist()
bytes_data = bytes(bytes_list)
handle_obj = handle_type()
ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data))
handles.append(handle_obj)
pointers: List[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer.value) # type: ignore
pointers.append(pointer.value)
else:
pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore
try:
opened_ptr = lib.cudaIpcOpenMemHandle(h)
pointers.append(opened_ptr.value)
except Exception as e:
print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}")
raise
dist.barrier(group=group)
return pointers
@staticmethod
def free_shared_buffer(
pointers: List[int], group: Optional[ProcessGroup] = None
) -> None:
if group is None:
group = dist.group.WORLD
rank = dist.get_rank(group=group)
lib = CudaRTLibrary()
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
if pointers and len(pointers) > rank and pointers[rank] is not None:
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
dist.barrier(group=group)
def test_correctness(self):
for world_size in self.world_sizes:
if world_size > torch.cuda.device_count():
available_gpus = torch.cuda.device_count()
if world_size > available_gpus:
print(
f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}"
)
continue
multi_process_parallel(world_size, self.correctness)
print(f"custom allreduce tp = {world_size}: OK")
def init_custom_allreduce(self, rank, world_size, group):
buffer_max_size = 8 * 1024 * 1024
barrier_max_size = 8 * (24 + 2) * 8
self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
self.tmp_result_buffer_ptrs = self.create_shared_buffer(
buffer_max_size, group=group
)
self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
)
self.custom_ptr = custom_ops.init_custom_reduce(
rank,
world_size,
self.rank_data,
self.buffer_ptrs,
self.tmp_result_buffer_ptrs,
self.barrier_in_ptrs,
self.barrier_out_ptrs,
)
def custom_allreduce(self, inp, out):
custom_ops.custom_reduce(self.custom_ptr, inp, out)
def free_custom_allreduce(self, group):
self.free_shared_buffer(self.buffer_ptrs, group)
self.free_shared_buffer(self.tmp_result_buffer_ptrs, group)
self.free_shared_buffer(self.barrier_in_ptrs, group)
self.free_shared_buffer(self.barrier_out_ptrs, group)
custom_ops.custom_dispose(self.custom_ptr)
@staticmethod
def init_distributed_env(world_size, rank, distributed_init_port):
device = torch.device("cuda:0")
torch.cuda.set_device(device)
ranks = [i for i in range(world_size)]
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
dist.init_process_group(
backend="nccl",
init_method=distributed_init_method,
rank=rank,
world_size=world_size,
)
group = torch.distributed.new_group(ranks, backend="gloo")
return group
# compare result with torch.distributed
def correctness(self, world_size, rank, distributed_init_port):
group = self.init_distributed_env(world_size, rank, distributed_init_port)
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
test_loop = 10
for sz in self.test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
for _ in range(test_loop):
inp1 = torch.randint(
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
)
out1 = torch.empty_like(inp1)
self.custom_allreduce(inp1, out1)
dist.all_reduce(inp1, group=group)
torch.testing.assert_close(out1, inp1)
self.free_custom_allreduce(group)
print(f"Running test for world_size={world_size}")
multi_process_parallel(
world_size, _run_correctness_worker, target_args=(self.test_sizes,)
)
print(f"custom allreduce tp = {world_size}: OK")
if __name__ == "__main__":
......