Unverified Commit 9fccda31 authored by Adarsh Shirawalmath, committed by GitHub

[Feature] use pytest for sgl-kernel (#4896)

parent 4ede6770
@@ -80,7 +80,8 @@ jobs:
       - name: Install
         run: |
-          pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1
+          bash scripts/ci_install_dependency.sh
+          pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.7.2
           pip3 uninstall sgl-kernel -y || true
           pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
           pip3 list | grep sgl-kernel
@@ -89,7 +90,7 @@ jobs:
         timeout-minutes: 30
         run: |
           cd sgl-kernel
-          find tests -name "test_*.py" | xargs -n 1 python3
+          pytest tests/
       - name: Uninstall dependencies
         run: |
...
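The CI step above now drives the whole suite through a single pytest invocation instead of looping over files with python3. A minimal sketch of the equivalent programmatic call, assuming it is run from the sgl-kernel directory (the -q flag is illustrative and not part of the commit):

import pytest

# Collect and run everything under tests/, mirroring the CI step `pytest tests/`.
raise SystemExit(pytest.main(["tests/", "-q"]))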
+import pytest
 import torch
 import torch.nn.functional as F
 from sgl_kernel import verify_tree_greedy
@@ -85,14 +86,14 @@ def test_verify_tree_greedy():
     print(f"{accept_index=}")
     print(f"{accept_token_num=}")
-    return predicts, accept_index, accept_token_num
-
-
-if __name__ == "__main__":
-    predicts, accept_index, accept_token_num = test_verify_tree_greedy()
+    # Check the expected output.
     assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
     assert accept_index.tolist() == [
         [0, 3, 4, 5],
         [6, 10, 11, -1],
     ]
     assert accept_token_num.tolist() == [3, 2]
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
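Each converted file keeps a __main__ guard, so invoking it directly with python3 (as the removed CI step did via find/xargs) now simply delegates to pytest. A minimal sketch of that pattern with a placeholder test body (the assertion is illustrative, not the verify_tree_greedy check above):

import pytest


def test_placeholder():
    # Stand-in for the real kernel assertions; illustrative only.
    assert [3, 2] == [3, 2]


if __name__ == "__main__":
    # `python3 this_file.py` and `pytest this_file.py` now run the same checks.
    pytest.main([__file__])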
+import pytest
 import torch
 import torch.nn.functional as F
 from sgl_kernel import tree_speculative_sampling_target_only
@@ -97,26 +98,21 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
     print(f"{accept_index=}")
     print(f"{accept_token_num=}")
-    return predicts, accept_index, accept_token_num
+    if threshold_single == 1 and threshold_acc == 1:
+        assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
+        assert accept_index.tolist() == [
+            [0, 3, 4, 5],
+            [6, 10, 11, -1],
+        ]
+        assert accept_token_num.tolist() == [3, 2]
+    elif threshold_single == 0 and threshold_acc == 0:
+        assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
+        assert accept_index.tolist() == [
+            [0, 1, 2, -1],
+            [6, 10, 11, -1],
+        ]
+        assert accept_token_num.tolist() == [2, 2]


 if __name__ == "__main__":
-    predicts, accept_index, accept_token_num = (
-        test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1)
-    )
-    assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
-    assert accept_index.tolist() == [
-        [0, 3, 4, 5],
-        [6, 10, 11, -1],
-    ]
-    assert accept_token_num.tolist() == [3, 2]
-
-    predicts, accept_index, accept_token_num = (
-        test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0)
-    )
-    assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
-    assert accept_index.tolist() == [
-        [0, 1, 2, -1],
-        [6, 10, 11, -1],
-    ]
-    assert accept_token_num.tolist() == [2, 2]
+    pytest.main([__file__])
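The expected values are selected by branching on the threshold arguments inside the test body. An equivalent way to enumerate the two configurations, in the same parametrize style the commit applies to the GEMM tests, could look like the sketch below (a sketch only, not what the commit does; the test name and the length check are placeholders):

import pytest


@pytest.mark.parametrize(
    "threshold_single, threshold_acc, expected_predicts",
    [
        (1, 1, [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]),
        (0, 0, [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]),
    ],
)
def test_thresholds_sketch(threshold_single, threshold_acc, expected_predicts):
    # In the real test the kernel is invoked with the two thresholds and
    # `predicts` is compared against expected_predicts; this sketch only shows
    # how the two configurations map onto parametrized cases.
    assert len(expected_predicts) == 12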
@@ -128,5 +128,4 @@ def test_awq_dequant_compare_implementations(

 if __name__ == "__main__":
-    # Run the specific test function directly
     pytest.main([__file__])
-import unittest
+import pytest
 import torch
 from sgl_kernel import cublas_grouped_gemm


 def torch_grouped_gemm(a_array, b_array, out_dtype):
-    c_array = []
-    for a, b in zip(a_array, b_array):
-        c_array.append(torch.matmul(a, b.t()).to(out_dtype))
-    return c_array
+    return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)]


-class TestGroupedGemm(unittest.TestCase):
-    def _test_accuracy(self, Ms, Ns, Ks, out_dtype):
-        group_count = len(Ms)
-        a_array = []
-        b_array = []
-        c_array_cublas = []
-        for i in range(group_count):
-            M, N, K = Ms[i], Ns[i], Ks[i]
-            a_array.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5)
-            b_array.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5)
-            c_array_cublas.append(torch.empty((M, N), device="cuda", dtype=out_dtype))
-
-        c_array_torch = torch_grouped_gemm(a_array, b_array, out_dtype)
-        cublas_grouped_gemm(a_array, b_array, c_array_cublas, out_dtype)
-        for i in range(group_count):
-            M, N, K = Ms[i], Ns[i], Ks[i]
-            torch.testing.assert_close(c_array_torch[i], c_array_cublas[i])
-            print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
-
-    def test_accuracy(self):
-        Ms = [1, 16, 32, 256, 1024]
-        Ns = [2, 16, 128, 256, 4096]
-        Ks = [3, 16, 32, 512, 8192]
-        out_dtypes = [torch.float16, torch.bfloat16]
-        for out_dtype in out_dtypes:
-            self._test_accuracy(Ms, Ns, Ks, out_dtype)
+skip_condition = not torch.cuda.is_available() or (
+    torch.version.cuda is None
+    or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
+)
+
+
+@pytest.mark.skipif(
+    skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
+)
+@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("M", [1, 16, 32, 256, 1024])
+@pytest.mark.parametrize("N", [2, 16, 128, 256, 4096])
+@pytest.mark.parametrize("K", [3, 16, 32, 512, 8192])
+def test_grouped_gemm_accuracy(out_dtype, M, N, K):
+    a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5
+    b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5
+    expected = torch.matmul(a, b.t()).to(out_dtype)
+    a_array = [a]
+    b_array = [b]
+    c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)]
+
+    result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0]
+    cublas_grouped_gemm(a_array, b_array, c_array, out_dtype)
+
+    torch.testing.assert_close(result_torch, expected)
+    torch.testing.assert_close(c_array[0], expected)


 if __name__ == "__main__":
-    if torch.cuda.is_available():
-        cuda_version = tuple(map(int, torch.version.cuda.split(".")))
-        if cuda_version >= (12, 5):
-            unittest.main()
-        else:
-            print(f"Cuda version {cuda_version} lower than 12.5, not executing tests.")
+    pytest.main([__file__])
-import unittest
-import random
+import os
 from typing import Optional, Type

+import pytest
 import torch
 from sgl_kernel import fp8_blockwise_scaled_mm


 def cdiv(a: int, b: int) -> int:
-    """Ceiling division."""
     return -(a // -b)
@@ -23,7 +24,6 @@ def baseline_scaled_mm(
     out_dtype: Type[torch.dtype],
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    # We treat N-dimensional group scaling as extended numpy-style broadcasting
     # in numpy simply stretches dimensions with an extent of 1 to match the
     # the target shape by repeating the data along that dimension (broadcasting)
@@ -51,62 +51,44 @@ def baseline_scaled_mm(
     scale_a = group_broadcast(scale_a, a.shape)
     scale_b = group_broadcast(scale_b, b.shape)
     output = torch.mm(
         (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
     ).to(out_dtype)
     if bias is not None:
         output = output + bias
     return output


-class TestFp8Gemm(unittest.TestCase):
-    def _test_accuracy_once(self, M, N, K, out_dtype, device):
-        fp8_info = torch.finfo(torch.float8_e4m3fn)
-        fp8_max, fp8_min = fp8_info.max, fp8_info.min
-
-        a_fp32 = (
-            (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-        )
-        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-
-        b_fp32 = (
-            (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-        )
-        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
-
-        scale_a_group_shape = (1, 128)
-        scale_b_group_shape = (128, 128)
-        scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
-        scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
-
-        scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
-        scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
-        scale_a = scale_a.t().contiguous().t()
-        scale_b = scale_b.t().contiguous().t()
-
-        o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
-        o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
-        o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
-
-        rtol = 0.02
-        atol = 1
-        torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
-        print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
-
-    def test_accuracy(self):
-        Ms = [1, 128, 512, 1024, 4096]
-        Ns = [128, 512, 1024, 4096]
-        Ks = [512, 1024, 4096, 8192, 16384]
-        out_dtypes = [torch.bfloat16, torch.float16]
-        for M in Ms:
-            for N in Ns:
-                for K in Ks:
-                    for out_dtype in out_dtypes:
-                        self._test_accuracy_once(M, N, K, out_dtype, "cuda")
+def _test_accuracy_once(M, N, K, out_dtype, device):
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max, fp8_min = fp8_info.max, fp8_info.min
+    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
+    scale_a_group_shape = (1, 128)
+    scale_b_group_shape = (128, 128)
+    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
+    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
+    scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
+    scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
+    scale_a = scale_a.t().contiguous().t()
+    scale_b = scale_b.t().contiguous().t()
+    o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
+    o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
+    rtol = 0.02
+    atol = 1
+    torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
+    print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
+
+
+@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
+@pytest.mark.parametrize("N", [128, 512, 1024, 4096])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
+def test_accuracy(M, N, K, out_dtype):
+    _test_accuracy_once(M, N, K, out_dtype, "cuda")


 if __name__ == "__main__":
-    unittest.main()
+    pytest.main([__file__])
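scale_shape and group_broadcast are helpers defined earlier in this file and are not shown in the hunk. Under the (1, 128) activation and (128, 128) weight group shapes used here, scale_shape presumably reduces each tensor dimension by its group size with ceiling division; a sketch of that assumption using the cdiv helper above (the helper name scale_shape_sketch is hypothetical):

def cdiv(a: int, b: int) -> int:
    return -(a // -b)


def scale_shape_sketch(shape, group_shape):
    # Hypothetical reconstruction: one scale per group_shape-sized block.
    return tuple(cdiv(dim, grp) for dim, grp in zip(shape, group_shape))


# (M, K) activation with per-row, per-128-channel scales (1, 128):
assert scale_shape_sketch((4, 512), (1, 128)) == (4, 4)
# (K, N) weight with 128x128 block scales:
assert scale_shape_sketch((512, 256), (128, 128)) == (4, 2)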
-import unittest
+import pytest
 import torch
 from sgl_kernel import fp8_scaled_mm


 def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
     o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
     o = o.to(torch.float32)
     temp1 = o * scale_a.view(-1, 1)
     temp2 = temp1 * scale_b.view(1, -1)
     final = temp2.to(out_dtype)
     if bias is not None:
         final = final + bias.view(1, -1)
     return final


-class TestFp8Gemm(unittest.TestCase):
-    def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
-        fp8_info = torch.finfo(torch.float8_e4m3fn)
-        fp8_max, fp8_min = fp8_info.max, fp8_info.min
-
-        a_fp32 = (
-            (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-        )
-        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-
-        b_fp32 = (
-            (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
-        )
-        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-
-        scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
-        scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
-        if with_bias:
-            bias = torch.randn((N,), device=device, dtype=out_dtype)
-        else:
-            bias = None
-        o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
-        b_fp8 = b_fp8.t()
-        o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
-        o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
-        rtol = 0.02
-        atol = 1
-        torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
-        print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
-
-    def test_accuracy(self):
-        Ms = [1, 128, 512, 1024, 4096]
-        Ns = [16, 128, 512, 1024, 4096]
-        Ks = [512, 1024, 4096, 8192, 16384]
-        bias_opts = [True, False]
-        out_dtypes = [torch.bfloat16, torch.float16]
-        for M in Ms:
-            for N in Ns:
-                for K in Ks:
-                    for with_bias in bias_opts:
-                        for out_dtype in out_dtypes:
-                            self._test_accuracy_once(
-                                M, N, K, with_bias, out_dtype, "cuda"
-                            )
+def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max, fp8_min = fp8_info.max, fp8_info.min
+    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+    scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
+    scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
+    if with_bias:
+        bias = torch.randn((N,), device=device, dtype=out_dtype)
+    else:
+        bias = None
+    b_fp8 = b_fp8.t()
+    o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
+    o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
+    rtol = 0.02
+    atol = 1
+    torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
+    print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
+
+
+@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
+@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("with_bias", [True, False])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
+def test_accuracy(M, N, K, with_bias, out_dtype):
+    _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")


 if __name__ == "__main__":
-    unittest.main()
+    pytest.main([__file__])
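torch_scaled_mm, the reference used above, applies scale_a per output row and scale_b per output column via broadcasting. A tiny worked sketch of that broadcasting on toy values (the numbers are illustrative only):

import torch

o = torch.ones(2, 3)                        # stand-in for A @ B in float32
scale_a = torch.tensor([1.0, 2.0])          # one scale per row, shape (M,)
scale_b = torch.tensor([1.0, 10.0, 100.0])  # one scale per column, shape (N,)

scaled = o * scale_a.view(-1, 1) * scale_b.view(1, -1)
assert scaled.tolist() == [[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]]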
-import unittest
+import pytest
 import torch
 from sgl_kernel import int8_scaled_mm
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
@@ -18,39 +17,31 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
     return o.to(out_dtype)


-class TestInt8Gemm(unittest.TestCase):
-    def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
-        a = to_int8(torch.randn((M, K), device=device) * 5)
-        b = to_int8(torch.randn((N, K), device=device).t() * 5)
-        scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
-        scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
-        if with_bias:
-            bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
-        else:
-            bias = None
-
-        o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-        o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-        o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-        torch.testing.assert_close(o, o1)
-        torch.testing.assert_close(o, o2)
-        print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
-
-    def test_accuracy(self):
-        Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]
-        Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
-        Ks = [512, 1024, 4096, 8192, 16384]
-        bias_opts = [True, False]
-        out_dtypes = [torch.float16, torch.bfloat16]
-        for M in Ms:
-            for N in Ns:
-                for K in Ks:
-                    for with_bias in bias_opts:
-                        for out_dtype in out_dtypes:
-                            self._test_accuracy_once(
-                                M, N, K, with_bias, out_dtype, "cuda"
-                            )
+def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
+    a = to_int8(torch.randn((M, K), device=device) * 5)
+    b = to_int8(torch.randn((N, K), device=device).t() * 5)
+    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
+    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
+    if with_bias:
+        bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
+    else:
+        bias = None
+    o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    torch.testing.assert_close(o, o1)
+    torch.testing.assert_close(o, o2)
+    print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")


+@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
+@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
+@pytest.mark.parametrize("with_bias", [True, False])
+@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
+def test_accuracy(M, N, K, with_bias, out_dtype):
+    _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")
+

 if __name__ == "__main__":
-    unittest.main()
+    pytest.main([__file__])
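to_int8 is a helper defined earlier in this file and not shown in the hunk; it presumably rounds and clamps a float tensor into the int8 range before casting. A sketch of that assumption (the name to_int8_sketch and its body are inferred, not taken from the diff):

import torch


def to_int8_sketch(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical reconstruction: round, clamp to [-128, 127], cast to int8.
    return x.round().clamp(-128, 127).to(torch.int8)


q = to_int8_sketch(torch.randn(4, 8) * 5)
assert q.dtype == torch.int8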
@@ -51,5 +51,4 @@ def test_per_token_quant_compare_implementations(

 if __name__ == "__main__":
-    # Run the specific test function directly
     pytest.main([__file__])
@@ -13,154 +13,185 @@ from torch.distributed import ProcessGroup
 from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary


+def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    ranks = list(range(world_size))
+    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
+    dist.init_process_group(
+        backend="nccl",
+        init_method=distributed_init_method,
+        rank=rank,
+        world_size=world_size,
+    )
+    group = dist.group.WORLD
+
+    buffer_max_size = 8 * 1024 * 1024
+    barrier_max_size = 8 * (24 + 2) * 8
+
+    buffer_ptrs = None
+    tmp_result_buffer_ptrs = None
+    barrier_in_ptrs = None
+    barrier_out_ptrs = None
+    custom_ptr = None
+
+    try:
+        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
+            buffer_max_size, group=group
+        )
+        tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
+            buffer_max_size, group=group
+        )
+        barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer(
+            barrier_max_size, group=group
+        )
+        barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer(
+            barrier_max_size, group=group
+        )
+        rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
+        custom_ptr = custom_ops.init_custom_reduce(
+            rank,
+            world_size,
+            rank_data,
+            buffer_ptrs,
+            tmp_result_buffer_ptrs,
+            barrier_in_ptrs,
+            barrier_out_ptrs,
+        )
+
+        test_loop = 10
+        for sz in test_sizes:
+            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+                for _ in range(test_loop):
+                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
+                    inp1_ref = inp1.clone()
+                    out1 = torch.empty_like(inp1)
+                    custom_ops.custom_reduce(custom_ptr, inp1, out1)
+                    dist.all_reduce(inp1_ref, group=group)
+                    torch.testing.assert_close(out1, inp1_ref)
+    finally:
+        dist.barrier(group=group)
+        if custom_ptr is not None:
+            custom_ops.custom_dispose(custom_ptr)
+        if buffer_ptrs:
+            TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
+        if tmp_result_buffer_ptrs:
+            TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group)
+        if barrier_in_ptrs:
+            TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group)
+        if barrier_out_ptrs:
+            TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group)
+        dist.destroy_process_group(group=group)
 def get_open_port() -> int:
+    # try ipv4
     try:
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
             s.bind(("127.0.0.1", 0))
             return s.getsockname()[1]
     except OSError:
+        # try ipv6
         with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
-            s.bind(("127.0.0.1", 0))
+            s.bind(("::1", 0))
             return s.getsockname()[1]
 def multi_process_parallel(
-    world_size: int,
-    test_target: Any,
+    world_size: int, test_target: Any, target_args: tuple = ()
 ) -> None:
+    mp.set_start_method("spawn", force=True)
+
     procs = []
     distributed_init_port = get_open_port()
     for i in range(world_size):
-        proc = mp.Process(
-            target=test_target,
-            args=(world_size, i, distributed_init_port),
-        )
+        proc_args = (world_size, i, distributed_init_port) + target_args
+        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
         proc.start()
         procs.append(proc)

     for i in range(world_size):
         procs[i].join()
-        assert procs[i].exitcode == 0
+        assert (
+            procs[i].exitcode == 0
+        ), f"Process {i} failed with exit code {procs[i].exitcode}"
 class TestCustomAllReduce(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        random.seed(42)
-        cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
-        cls.world_sizes = [2, 4, 8]
+    test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
+    world_sizes = [2, 4, 8]

     @staticmethod
     def create_shared_buffer(
         size_in_bytes: int, group: Optional[ProcessGroup] = None
     ) -> List[int]:
-        """
-        Creates a shared buffer and returns a list of pointers
-        representing the buffer on all processes in the group.
-        """
         lib = CudaRTLibrary()
         pointer = lib.cudaMalloc(size_in_bytes)
         handle = lib.cudaIpcGetMemHandle(pointer)
-        if group is None:
-            group = dist.group.WORLD
         world_size = dist.get_world_size(group=group)
         rank = dist.get_rank(group=group)
-        handles = [None] * world_size
-        dist.all_gather_object(handles, handle, group=group)
+
+        handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
+        input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}")
+        gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)]
+        dist.all_gather(gathered_tensors, input_tensor, group=group)
+
+        handles = []
+        handle_type = type(handle)
+        for tensor in gathered_tensors:
+            bytes_list = tensor.cpu().tolist()
+            bytes_data = bytes(bytes_list)
+            handle_obj = handle_type()
+            ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data))
+            handles.append(handle_obj)

         pointers: List[int] = []
         for i, h in enumerate(handles):
             if i == rank:
-                pointers.append(pointer.value)  # type: ignore
+                pointers.append(pointer.value)
             else:
-                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore
+                try:
+                    opened_ptr = lib.cudaIpcOpenMemHandle(h)
+                    pointers.append(opened_ptr.value)
+                except Exception as e:
+                    print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}")
+                    raise
+
+        dist.barrier(group=group)
         return pointers

     @staticmethod
     def free_shared_buffer(
         pointers: List[int], group: Optional[ProcessGroup] = None
     ) -> None:
-        if group is None:
-            group = dist.group.WORLD
         rank = dist.get_rank(group=group)
         lib = CudaRTLibrary()
-        lib.cudaFree(ctypes.c_void_p(pointers[rank]))
+        if pointers and len(pointers) > rank and pointers[rank] is not None:
+            lib.cudaFree(ctypes.c_void_p(pointers[rank]))
+        dist.barrier(group=group)
     def test_correctness(self):
         for world_size in self.world_sizes:
-            if world_size > torch.cuda.device_count():
+            available_gpus = torch.cuda.device_count()
+            if world_size > available_gpus:
+                print(
+                    f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}"
+                )
                 continue
-            multi_process_parallel(world_size, self.correctness)
+
+            print(f"Running test for world_size={world_size}")
+            multi_process_parallel(
+                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
+            )
             print(f"custom allreduce tp = {world_size}: OK")
-
-    def init_custom_allreduce(self, rank, world_size, group):
-        buffer_max_size = 8 * 1024 * 1024
-        barrier_max_size = 8 * (24 + 2) * 8
-        self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
-        self.tmp_result_buffer_ptrs = self.create_shared_buffer(
-            buffer_max_size, group=group
-        )
-        self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
-        self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
-        self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
-        )
-        self.custom_ptr = custom_ops.init_custom_reduce(
-            rank,
-            world_size,
-            self.rank_data,
-            self.buffer_ptrs,
-            self.tmp_result_buffer_ptrs,
-            self.barrier_in_ptrs,
-            self.barrier_out_ptrs,
-        )
-
-    def custom_allreduce(self, inp, out):
-        custom_ops.custom_reduce(self.custom_ptr, inp, out)
-
-    def free_custom_allreduce(self, group):
-        self.free_shared_buffer(self.buffer_ptrs, group)
-        self.free_shared_buffer(self.tmp_result_buffer_ptrs, group)
-        self.free_shared_buffer(self.barrier_in_ptrs, group)
-        self.free_shared_buffer(self.barrier_out_ptrs, group)
-        custom_ops.custom_dispose(self.custom_ptr)
-
-    @staticmethod
-    def init_distributed_env(world_size, rank, distributed_init_port):
-        device = torch.device("cuda:0")
-        torch.cuda.set_device(device)
-        ranks = [i for i in range(world_size)]
-        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
-        dist.init_process_group(
-            backend="nccl",
-            init_method=distributed_init_method,
-            rank=rank,
-            world_size=world_size,
-        )
-        group = torch.distributed.new_group(ranks, backend="gloo")
-        return group
-
-    # compare result with torch.distributed
-    def correctness(self, world_size, rank, distributed_init_port):
-        group = self.init_distributed_env(world_size, rank, distributed_init_port)
-        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
-
-        test_loop = 10
-        for sz in self.test_sizes:
-            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-                for _ in range(test_loop):
-                    inp1 = torch.randint(
-                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
-                    )
-                    out1 = torch.empty_like(inp1)
-                    self.custom_allreduce(inp1, out1)
-                    dist.all_reduce(inp1, group=group)
-                    torch.testing.assert_close(out1, inp1)
-
-        self.free_custom_allreduce(group)
 if __name__ == "__main__":
...
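With the correctness worker factored out of the TestCase, the launch path is explicit: multi_process_parallel spawns one process per rank and appends target_args after the (world_size, rank, distributed_init_port) triple. A minimal sketch of that call outside unittest, assuming at least two CUDA devices and the helpers defined in the diff above (the reduced size list is illustrative):

if torch.cuda.device_count() >= 2:
    # Each worker receives (world_size, rank, distributed_init_port, test_sizes).
    multi_process_parallel(
        2,
        _run_correctness_worker,
        target_args=([512, 4096, 32768],),
    )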