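"""Tests for the sgl_kernel CPU normalization ops.

Checks rmsnorm_cpu, fused_add_rmsnorm_cpu, and l2norm_cpu against a native
PyTorch reference implementation across several shapes and dtypes.
"""
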
import itertools
import unittest
from typing import Optional, Tuple, Union

import sgl_kernel  # noqa: F401  # needed so the torch.ops.sgl_kernel ops are registered
import torch
from utils import make_non_contiguous, precision

from sglang.test.test_utils import CustomTestCase

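# Fix the seed so the randomly generated test inputs are reproducible.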
torch.manual_seed(0)


class TestNorm(CustomTestCase):
    M = [4096, 1024]
    N = [4096, 4096 + 13]  # the +13 makes N odd, covering sizes that do not divide evenly into vector blocks
    dtype = [torch.float16, torch.bfloat16]

    def _forward_native(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float = 1e-6,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
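        """Reference RMSNorm computed in float32.

        Returns x * rsqrt(mean(x^2, dim=-1) + eps) * weight, cast back to the
        input dtype. When a residual is given, it is added to x first and the
        sum is also returned as the updated residual.
        """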
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype) * weight
        if residual is None:
            return x
        else:
            return x, residual

    def _norm_test(self, m, n, dtype):
        x = torch.randn([m, n], dtype=dtype)
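        # Exercise the kernels on a non-contiguous (strided) input.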
        x = make_non_contiguous(x)
        hidden_size = x.size(-1)
        weight = torch.randn(hidden_size, dtype=dtype)
        variance_epsilon = 1e-6

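        # Out-of-place RMSNorm kernel vs. the float32 reference path.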
        out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon)
        ref_out = self._forward_native(x, weight, variance_epsilon)

        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

        ref_x = x.clone()
        residual = torch.randn([m, hidden_size], dtype=dtype)
        ref_residual = residual.clone()

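        # The fused kernel works in place: x receives the normalized output
        # and residual receives the pre-normalization sum x + residual.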
        torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
            x, residual, weight, variance_epsilon
        )

        ref_x, ref_residual = self._forward_native(
            ref_x, weight, variance_epsilon, ref_residual
        )

        torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol)
        torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol)

    def _l2norm_test(self, m, n, dtype):
        x = torch.randn([m, n], dtype=dtype)
        hidden_size = x.size(-1)
        fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
        variance_epsilon = 1e-6

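        # l2norm_cpu takes no weight, so the reference path is RMSNorm with
        # an all-ones weight.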
        out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
        ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)

        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_norm(self):
        for m, n, dtype in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=m, n=n, dtype=dtype):
                self._norm_test(m, n, dtype)
                self._l2norm_test(m, n, dtype)


if __name__ == "__main__":
    unittest.main()