# test_triton_scaled_mm.py
import itertools
import unittest
from typing import Optional

import torch
import torch.testing

from sglang.srt.layers.quantization.fp8_kernel import triton_scaled_mm
from sglang.test.test_utils import CustomTestCase


def torch_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference implementation using float32 for stability.

    Computes ``scale_a * (a @ b) * scale_b.T`` (plus ``bias`` when given)
    entirely in float32 and casts the result to ``out_dtype`` at the end.

    Args:
        a: Left operand of the matrix product, 2-D.
        b: Right operand of the matrix product, 2-D.
        scale_a: Scale applied along the rows of the product. A 0-D or 1-D
            tensor is viewed as a column vector; a 2-D tensor is used as is.
        scale_b: Scale applied along the columns of the product. A 0-D or
            1-D tensor is viewed as a column vector before being transposed;
            a 2-D tensor is transposed as is.
        out_dtype: dtype of the returned tensor.
        bias: Optional tensor added to the scaled product; it must broadcast
            against the product.

    Returns:
        The scaled (and optionally biased) product cast to ``out_dtype``.
    """

    def _as_column(scale: torch.Tensor) -> torch.Tensor:
        # Scalars and flat vectors become (n, 1) so that row scales broadcast
        # over rows and, after the transpose, column scales over columns.
        scale = scale.to(torch.float32)
        return scale.reshape(-1, 1) if scale.dim() <= 1 else scale

    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
    out = _as_column(scale_a) * out * _as_column(scale_b).T
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(out_dtype)


class TestScaledMM(CustomTestCase):
    """Compare ``triton_scaled_mm`` with the float32 ``torch_scaled_mm`` reference."""

    @classmethod
    def setUpClass(cls):
        """Skip the whole class when no CUDA device is available."""
        if not torch.cuda.is_available():
            raise unittest.SkipTest("This test requires a CUDA device.")
        # NOTE(review): this changes the process-wide default device and is
        # never restored afterwards -- confirm no other test in the same
        # process relies on the previous default.
        torch.set_default_device("cuda")

    @staticmethod
    def _fp8_available():
        """Return True when a ``float8_e4m3fn`` tensor can be created on CUDA."""
        try:
            torch.tensor([1.0], dtype=torch.float8_e4m3fn, device="cuda")
        except Exception:
            # Best-effort probe: any ordinary failure means "no FP8". Unlike a
            # bare ``except:``, KeyboardInterrupt and SystemExit still propagate.
            return False
        return True

    def _make_inputs(self, M, K, N, in_dtype):
        """Build random ``(M, K)`` and ``(K, N)`` operands of ``in_dtype`` on CUDA.

        For ``torch.int8`` the values are integers drawn from ``[-8, 8)``.
        Any other dtype is treated as FP8: float16 normal samples scaled by
        0.1 are clamped to ``[-0.3, 0.3]`` and then cast to ``in_dtype``.
        """
        if in_dtype == torch.int8:
            a = torch.randint(-8, 8, (M, K), dtype=in_dtype, device="cuda")
            b = torch.randint(-8, 8, (K, N), dtype=in_dtype, device="cuda")
        else:  # fp8
            a = torch.clamp(
                0.1 * torch.randn((M, K), dtype=torch.float16, device="cuda"), -0.3, 0.3
            ).to(in_dtype)
            b = torch.clamp(
                0.1 * torch.randn((K, N), dtype=torch.float16, device="cuda"), -0.3, 0.3
            ).to(in_dtype)
        return a, b

    def test_basic_cases(self):
        """Test core functionality with reduced precision requirements"""
        # Each entry: (M, K, N, input dtype, output dtype, add a bias?)
        test_configs = [
            (32, 32, 32, torch.int8, torch.float16, False),
            (64, 64, 64, torch.int8, torch.float16, True),
        ]

        if self._fp8_available():
            test_configs.append((32, 32, 32, torch.float8_e4m3fn, torch.float16, False))
        else:
            print("FP8 not supported, skipping")

        for M, K, N, in_dtype, out_dtype, with_bias in test_configs:
            with self.subTest(M=M, K=K, N=N, dtype=in_dtype, bias=with_bias):
                print(f"Currently testing with in_dtype: {in_dtype}")
                # Re-seed per configuration so each sub-test is reproducible
                # on its own.
                torch.manual_seed(42)

                lhs, rhs = self._make_inputs(M, K, N, in_dtype)
                # Per-row and per-column scales drawn from [0.1, 0.15).
                scale_a = 0.1 + 0.05 * torch.rand(
                    (M, 1), dtype=torch.float32, device="cuda"
                )
                scale_b = 0.1 + 0.05 * torch.rand(
                    (N, 1), dtype=torch.float32, device="cuda"
                )
                bias = (
                    0.01 * torch.randn((M, N), dtype=out_dtype, device="cuda")
                    if with_bias
                    else None
                )

                triton_out = triton_scaled_mm(
                    lhs, rhs, scale_a, scale_b, out_dtype, bias
                )
                ref_out = torch_scaled_mm(lhs, rhs, scale_a, scale_b, out_dtype, bias)

                # Use relaxed tolerances; non-int8 inputs get the looser pair.
                is_int8 = in_dtype == torch.int8
                rtol = 0.15 if is_int8 else 0.25
                atol = 0.1 if is_int8 else 0.15

                torch.testing.assert_close(triton_out, ref_out, rtol=rtol, atol=atol)


# Run the tests with verbose output when this file is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)