test_layernorm.py 4.92 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

# This registers op implementations
import vllm.kernels  # noqa: F401
8
9
10
11
12
13
14
from tests.ir.ir_test_utils import (
    COMMON_HIDDEN_SIZES,
    NUM_TOKENS,
    assert_close,
    clone_args,
    supported_providers,
)
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from tests.kernels.allclose_default import get_default_rtol
from vllm import ir
from vllm.platforms import current_platform

rms_norm_native = ir.ops.rms_norm.impls["native"].impl_fn


@pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
    reason="Currently only kernels on CUDA, ROCm and XPU",
)
def test_rms_norm_registration():
    """Verify which rms_norm providers report themselves as supported.

    The expected ``supported`` flag of every registered implementation is
    derived from the current platform and compared, as a whole mapping,
    against what ``ir.ops.rms_norm.impls`` actually reports. Because the
    comparison is on the full dict, an unexpected extra or missing provider
    also fails the test.
    """
    expected = {
        # The native implementation is expected to be supported everywhere.
        "native": True,
        "vllm_c": current_platform.is_cuda_alike(),
        "aiter": current_platform.is_rocm(),
        # This test expects oink to be reported as unsupported.
        "oink": False,
        "xpu_kernels": current_platform.is_xpu(),
    }

    # Collect the support flag reported by each registered implementation.
    reported = {}
    for name, registered_impl in ir.ops.rms_norm.impls.items():
        reported[name] = registered_impl.supported

    assert reported == expected


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("n_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", COMMON_HIDDEN_SIZES)
@pytest.mark.parametrize("epsilon", [1e-6, 1e-5])
@pytest.mark.skipif(
    not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
    reason="Currently only kernels on CUDA, ROCm and XPU",
)
class TestRMSNorm:
    """Tests for the ``rms_norm`` IR op and its provider implementations.

    Every test method is parametrized over dtype, token count, hidden size
    and epsilon by the class-level decorators, and the whole class is skipped
    unless the current platform is CUDA-alike or XPU.
    """

    @classmethod
    def setup_class(cls, **kwargs):
        """Make the platform's device the default for tensors created here.

        Note that this changes torch's process-wide default device and is not
        undone by this class.
        """
        torch.set_default_device(current_platform.device_type)

    def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon):
        """Check basic properties of the native rms_norm implementation.

        ``n_tokens`` and ``hidden_size`` come from the class-level
        parametrization but are not used here: inputs are always generated
        with ``num_tokens=4`` and ``hidden_size=8``.
        """
        x, weight, epsilon = ir.ops.rms_norm.generate_inputs(
            num_tokens=4, hidden_size=8, dtype=dtype, epsilon=epsilon
        )
        out = rms_norm_native(x, weight, epsilon=epsilon)

        # Check shape, dtype, device
        assert out.shape == x.shape
        assert out.dtype == x.dtype
        assert out.device == x.device

        # Check the scaling property of rms norm: doubling the input should
        # leave the output unchanged within tolerance.
        out2 = rms_norm_native(x * 2.0, weight, epsilon=epsilon)
        torch.testing.assert_close(out2, out, rtol=get_default_rtol(out), atol=1e-3)

        # Check behavior with and without weight: an all-ones weight must
        # match passing no weight at all.
        weight1 = torch.ones_like(weight)
        out3 = rms_norm_native(x, weight1, epsilon=epsilon)
        out4 = rms_norm_native(x, None, epsilon=epsilon)
        torch.testing.assert_close(out3, out4)

    @pytest.mark.parametrize("provider", supported_providers(ir.ops.rms_norm))
    def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider):
        """Compare one provider implementation against the native reference.

        Skips when the provider reports that it does not support the
        generated arguments. Otherwise checks, in order: agreement with the
        native implementation, exact agreement between the dispatched op and
        a direct call, rejection of a ``variance_size`` override, and
        equivalence of ``weight=None`` with an all-ones weight.
        """
        impl = ir.ops.rms_norm.impls[provider]
        x, weight, eps = ir.ops.rms_norm.generate_inputs(
            num_tokens=n_tokens, hidden_size=hidden_size, dtype=dtype, epsilon=epsilon
        )
        args = (x, weight, eps)

        if not impl.supports_args(*args):
            pytest.skip(f"{provider} does not support args")

        # Each call gets its own clone of the arguments.
        ref_output = rms_norm_native(*clone_args(args))
        output = impl.impl_fn(*clone_args(args))
        assert_close(ir.ops.rms_norm, output, ref_output)

        # check that dispatched call matches direct call
        with ir.ops.rms_norm.set_priority([provider, "native"]):
            out_dispatched = ir.ops.rms_norm(*args)
        out_direct = impl.impl_fn(*args)
        # Zero tolerances: the two results must be exactly equal.
        torch.testing.assert_close(out_dispatched, out_direct, rtol=0.0, atol=0.0)

        # none of these support variance_size override
        # (checked both positionally and by keyword)
        assert not impl.supports_args(x, weight, eps, 4)
        assert not impl.supports_args(x, weight, eps, variance_size=4)

        # test weight=None behavior
        out_no_weight = impl.impl_fn(x, None, eps)
        out_unit_weight = impl.impl_fn(x, torch.ones_like(weight), eps)
        assert_close(ir.ops.rms_norm, out_no_weight, out_unit_weight)

    @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels", "native"])
    def test_torch_opcheck(self, dtype, n_tokens, hidden_size, epsilon, provider):
        """Run ``torch.library.opcheck`` on the registered torch op.

        Skips when the provider's implementation is not supported on the
        current platform.
        """
        if not ir.ops.rms_norm.impls[provider].supported:
            pytest.skip(f"{provider} impl not supported on this platform")

        args = ir.ops.rms_norm.generate_inputs(
            num_tokens=n_tokens, hidden_size=hidden_size, dtype=dtype, epsilon=epsilon
        )

        # When checking the torch op, we have to set priority and use dispatch
        with ir.ops.rms_norm.set_priority([provider, "native"]):
            torch.library.opcheck(torch.ops.vllm_ir.rms_norm, args)
118
119
120
121
122
123
124
125
126
127
128
129
130
131


@pytest.mark.skipif(
    not current_platform.is_rocm(),
    reason="aiter is only supported on ROCm",
)
def test_aiter_rejects_unsupported_dtypes():
    """Check that the aiter rms_norm impl rejects float32/float64 inputs.

    For each dtype, inputs are generated with 8 tokens, a hidden size of
    4096 and epsilon 1e-5, and the implementation's ``supports_args`` must
    return a falsy value for them. Only runs on ROCm.
    """
    # Generated inputs are placed on the platform's device by default.
    torch.set_default_device(current_platform.device_type)
    aiter_impl = ir.ops.rms_norm.impls["aiter"]

    for dtype in (torch.float32, torch.float64):
        inputs = ir.ops.rms_norm.generate_inputs(
            num_tokens=8, hidden_size=4096, dtype=dtype, epsilon=1e-5
        )
        accepted = aiter_impl.supports_args(*inputs)
        assert not accepted, f"aiter should reject dtype={dtype}"