"vllm/model_executor/models/lfm2_moe.py" did not exist on "de533ab2a14192e461900a4950e2b426d99a6862"
test_layernorm.py 5.01 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import pytest
4
5
import torch

6
from tests.kernels.quant_utils import FP8_DTYPE
7
from tests.kernels.utils import opcheck
8
from vllm.model_executor.layers.layernorm import RMSNorm
9
from vllm.platforms import current_platform
10

11
12
# Parameter grids shared by the tests in this module.
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
                8199]  # Arbitrary values for testing
ADD_RESIDUAL = [False, True]
SEEDS = [0]
# Only "cuda:0" when exactly one device is reported; otherwise both "cuda:0"
# and "cuda:1" are listed.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
20

21

22
23
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check ``RMSNorm.__call__`` against ``RMSNorm.forward_native``.

    Builds a randomly weighted ``RMSNorm`` layer and a scaled random input
    (plus an optional residual), runs the native reference and then the
    layer's default forward, compares the results with a loose tolerance,
    and finally runs ``opcheck`` on the matching ``torch.ops._C`` op.

    Args:
        num_tokens: Number of rows in the random input.
        hidden_size: Number of columns in the input; also the layer width.
        add_residual: Whether a residual tensor is passed to the layer.
        dtype: Dtype of the layer weights and of the input.
        seed: Seed handed to ``current_platform.seed_everything``.
        device: Device string installed as the torch default device.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_out = layer.forward_native(x, residual)
    out = layer(x, residual)
    # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
    # numerical errors than other operators because they involve reductions.
    # Therefore, we use a larger tolerance.
    if add_residual:
        # With a residual both calls return a pair; compare element-wise.
        torch.testing.assert_close(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
    else:
        torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)

    if residual is not None:
        opcheck(torch.ops._C.fused_add_rms_norm,
                (x, residual, layer.weight.data, layer.variance_epsilon))
    else:
        opcheck(torch.ops._C.rms_norm,
                (out, x, layer.weight.data, layer.variance_epsilon))
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_rms_norm_quant(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
    quant_scale: float,
    seed: int,
    device: str,
) -> None:
    """Compare fused RMSNorm + static FP8 quant ops with the two-step path.

    The fused op (with or without residual add) writes into one FP8 buffer;
    the reference path runs the norm op followed by
    ``static_scaled_fp8_quant`` into a second FP8 buffer. Both buffers are
    cast to float32 and compared, and ``opcheck`` is run on the fused op.

    Args:
        num_tokens: Number of rows in the random input.
        hidden_size: Number of columns in the input and length of the weight.
        add_residual: Selects the residual-adding variants of the ops.
        dtype: Dtype of the weight, input and residual tensors.
        quant_scale: Scalar wrapped in a float32 tensor and passed to the
            quantizing ops.
        seed: Seed handed to ``current_platform.seed_everything``.
        device: Device string installed as the torch default device.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    eps = 1e-6
    ops = torch.ops._C
    loose = {"atol": 1e-2, "rtol": 1e-2}
    tight = {"atol": 1e-3, "rtol": 1e-3}

    # Random draws happen in the order weight -> input -> residual.
    gamma = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
    shrink = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= shrink
    res_ref = torch.randn_like(x) * shrink if add_residual else None
    # The fused op gets its own copy of the residual.
    res_fused = res_ref.clone() if add_residual else None

    normed = torch.empty_like(x)
    q_ref = torch.empty_like(x, dtype=FP8_DTYPE)
    q_fused = torch.empty_like(q_ref)

    scale_t = torch.tensor(quant_scale, dtype=torch.float32)

    if not add_residual:
        ops.rms_norm_static_fp8_quant(q_fused, x, gamma, scale_t, eps)

        # Reference: normalize into a scratch buffer, then quantize it.
        ops.rms_norm(normed, x, gamma, eps)
        ops.static_scaled_fp8_quant(q_ref, normed, scale_t)

        opcheck(ops.rms_norm_static_fp8_quant,
                (q_fused, x, gamma, scale_t, eps))
    else:
        ops.fused_add_rms_norm_static_fp8_quant(q_fused, x, res_fused, gamma,
                                                scale_t, eps)

        # The unfused kernel is in-place, so it runs second and on a clone
        # of the input rather than on the input itself.
        x_ref = x.clone()
        ops.fused_add_rms_norm(x_ref, res_ref, gamma, eps)
        ops.static_scaled_fp8_quant(q_ref, x_ref, scale_t)

        torch.cuda.synchronize()
        # The residual buffers handed to each path must agree afterwards.
        torch.testing.assert_close(res_fused, res_ref, **loose)

        opcheck(ops.fused_add_rms_norm_static_fp8_quant,
                (q_fused, x, res_fused, gamma, scale_t, eps))

    fused_f32 = q_fused.to(dtype=torch.float32)
    ref_f32 = q_ref.to(dtype=torch.float32)
    torch.testing.assert_close(fused_f32, ref_f32, **tight)