import pytest
import torch

from sgl_kernel import rmsnorm


def llama_rms_norm(x, w, eps=1e-6):
    """Reference LLaMA-style RMSNorm: normalize in fp32, scale by w, cast back.

    Mirrors the usual HF implementation: x / sqrt(mean(x^2) + eps) * w,
    with the reduction done in float32 for accuracy.
    """
    input_dtype = x.dtype
    x_fp32 = x.float()
    mean_sq = x_fp32.pow(2).mean(dim=-1, keepdim=True)
    normed = x_fp32 * torch.rsqrt(mean_sq + eps) * w.float()
    return normed.to(input_dtype)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):
    """Check sgl_kernel.rmsnorm against the fp32 reference on CUDA device 0.

    Exercises both call styles: with a caller-provided output buffer
    (out=...) and with the kernel allocating its own result.
    """
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)
    expected = llama_rms_norm(x, w)
    if specify_out:
        actual = torch.empty_like(x)
        rmsnorm(x, w, out=actual)
    else:
        actual = rmsnorm(x, w)
    # fp16 kernel vs fp32 reference: loose-but-standard tolerances.
    torch.testing.assert_close(expected, actual, rtol=1e-3, atol=1e-3)