"vscode:/vscode.git/clone" did not exist on "aa78aeaa0f4fee0ccb0184914266d6170cfc848f"
test_fused_quant_layernorm.py 5.96 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from typing import Optional, Tuple, Union

import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm

DTYPES = [torch.bfloat16, torch.float]
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
VEC_HIDDEN_SIZES = range(1024, 1030)
# Avoid combinatorial explosion with full Cartesian product
NUM_TOKENS_HIDDEN_SIZES = [
    *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]],
    *[(83, i) for i in [1, 1033, 2048, 5120]],
    *[(2048, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5137]],
    *[(4096, i) for i in [1, 64, 5137]],
]

ADD_RESIDUAL = [False, True]
SCALE_UBS = [True, False]
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

EPS = 1e-6

## Helpers


def as_float32_tensor(
    x: Union[float, torch.Tensor],
    device: Union[str, torch.device] = 'cuda',
) -> torch.Tensor:
    """Convert a Python scalar or a tensor to a float32 tensor on ``device``.

    ``torch.as_tensor`` reuses ``x`` without copying when it is already a
    float32 tensor on the requested device.

    Args:
        x: Scalar or tensor to convert.
        device: Target device. Defaults to ``'cuda'``, matching the previous
            hard-coded behavior.

    Returns:
        A ``torch.float32`` tensor on ``device``.
    """
    return torch.as_tensor(x, dtype=torch.float32, device=device)


def ref_rms_norm(rms_norm_layer: RMSNorm,
                 x: torch.Tensor,
                 residual: Optional[torch.Tensor]) \
        -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Reference (unfused) RMSNorm using the layer's ``forward_native``.

    Args:
        rms_norm_layer: Layer whose ``forward_native`` is the reference.
        x: Input activations.
        residual: Optional residual input. The layer receives a clone, so any
            in-place update it performs does not touch the caller's tensor.

    Returns:
        ``(out, residual)``: the normalized output and the residual returned
        by ``forward_native``, or ``None`` when no residual was given.
    """
    if residual is None:
        return rms_norm_layer.forward_native(x), None

    out, new_residual = rms_norm_layer.forward_native(x, residual.clone())
    return out, new_residual


def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
                                x: torch.Tensor,
                                quant_dtype: torch.dtype,
                                residual: Optional[torch.Tensor],
                                scale_ub: Optional[torch.Tensor]) \
        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Reference path: unfused RMSNorm, then a separate dynamic quant op.

    Args:
        rms_norm_layer: Layer providing the reference ``forward_native``.
        x: Input activations.
        quant_dtype: ``torch.int8`` or ``torch.float8_e4m3fn``.
        residual: Optional residual added inside the norm.
        scale_ub: Optional scale upper bound; only accepted for fp8 output.

    Returns:
        ``(quantized, scales, residual)``. The scales come from
        ``ops.scaled_fp8_quant`` (per-token dynamic mode) or
        ``ops.scaled_int8_quant``.
    """
    # A scale upper bound is only exercised together with fp8 output.
    assert scale_ub is None or quant_dtype == torch.float8_e4m3fn

    normed, new_residual = ref_rms_norm(rms_norm_layer, x, residual)

    if quant_dtype != torch.float8_e4m3fn:
        assert quant_dtype == torch.int8
        quantized, scales = ops.scaled_int8_quant(normed)
    else:
        quantized, scales = ops.scaled_fp8_quant(
            normed, scale_ub=scale_ub, use_per_token_if_dynamic=True)

    return quantized, scales, new_residual


def ref_impl(rms_norm_layer: RMSNorm,
             x: torch.Tensor,
             quant_dtype: torch.dtype,
             residual: Optional[torch.Tensor],
             scale_ub: Optional[torch.Tensor]) \
        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Reference implementation entry point used by the test.

    Delegates unchanged to ``ref_dynamic_per_token_quant`` and returns its
    ``(quantized, scales, residual)`` tuple.
    """
    return ref_dynamic_per_token_quant(rms_norm_layer=rms_norm_layer,
                                       x=x,
                                       quant_dtype=quant_dtype,
                                       residual=residual,
                                       scale_ub=scale_ub)


def ops_dynamic_per_token_quant(weight: torch.Tensor,
                                x: torch.Tensor,
                                quant_dtype: torch.dtype,
                                residual: Optional[torch.Tensor],
                                scale_ub: Optional[torch.Tensor]) \
        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Run the fused ``ops.rms_norm_dynamic_per_token_quant`` op.

    Args:
        weight: RMSNorm weight.
        x: Input activations.
        quant_dtype: Output quantized dtype.
        residual: Optional residual. A clone is passed to the op, and that
            same clone is returned. The op presumably updates it in place
            (its result is not reassigned) — TODO confirm.
        scale_ub: Optional scale upper bound.

    Returns:
        ``(quantized, scales, residual_clone_or_None)``.
    """
    residual_copy = None if residual is None else residual.clone()
    quantized, scales = ops.rms_norm_dynamic_per_token_quant(
        x, weight, EPS, quant_dtype, scale_ub, residual_copy)
    return quantized, scales, residual_copy


def ops_impl(weight: torch.Tensor,
             x: torch.Tensor,
             quant_dtype: torch.dtype,
             residual: Optional[torch.Tensor],
             scale_ub: Optional[torch.Tensor]) \
        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Fused-kernel entry point used by the test.

    Delegates unchanged to ``ops_dynamic_per_token_quant`` and returns its
    ``(quantized, scales, residual)`` tuple.
    """
    return ops_dynamic_per_token_quant(weight=weight,
                                       x=x,
                                       quant_dtype=quant_dtype,
                                       residual=residual,
                                       scale_ub=scale_ub)


@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    scale_ub: bool,
    dtype: torch.dtype,
    quant_dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare the fused RMSNorm + dynamic per-token quant op to the
    unfused reference, then run ``opcheck`` on the raw custom op.

    ``scale_ub`` is a boolean test flag. When True, a scale upper bound
    tensor is derived from the data and passed to both implementations.
    This combination is only tested for fp8 output.
    """
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # `scale_ub` is a bool, so test its truthiness: an `is not None` check
    # is always true for a bool, which skipped every int8 case and never
    # exercised fp8 without an upper bound.
    if scale_ub and quant_dtype != torch.float8_e4m3fn:
        pytest.skip("scale_ub is only tested with float8_e4m3fn output")

    layer = RMSNorm(hidden_size, EPS).to(dtype=dtype)

    # Make weights
    layer.weight.data.normal_(mean=1.0, std=0.1)

    # Make inputs, scaled down by 1 / hidden_size.
    scale = 1 / hidden_size
    x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
    residual = torch.randn_like(x) * scale if add_residual else None

    # Keep the bound tensor separate from the bool flag. Place it on x's
    # device rather than a hard-coded 'cuda' (cuda:0), so the cuda:1
    # parametrization does not mix devices.
    scale_ub_tensor: Optional[torch.Tensor] = None
    if scale_ub:
        rms_x, _ = ref_rms_norm(layer, x, residual)
        scale_ub_tensor = torch.mean(rms_x).to(dtype=torch.float32,
                                               device=x.device)

    ref_out, ref_scales, ref_residual = \
        ref_impl(layer, x, quant_dtype, residual, scale_ub_tensor)
    ops_out, ops_scales, ops_residual = \
        ops_impl(layer.weight, x, quant_dtype, residual, scale_ub_tensor)

    assert ref_out.dtype == quant_dtype
    assert ops_out.dtype == quant_dtype
    assert torch.allclose(ref_scales, ops_scales)
    if quant_dtype == torch.int8:
        # big atol to account for round-off errors.
        assert torch.allclose(ref_out, ops_out, atol=1)
    else:
        assert torch.allclose(ref_out.to(dtype=torch.float32),
                              ops_out.to(dtype=torch.float32))
    if add_residual:
        assert torch.allclose(ref_residual, ops_residual)

    output = torch.empty_like(x, dtype=quant_dtype)
    scales = torch.empty((x.numel() // x.shape[-1], 1),
                         device=x.device,
                         dtype=torch.float32)

    # Use the same epsilon as the functional comparison above.
    opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant,
            (output, x, layer.weight, scales, EPS, scale_ub_tensor, residual))