# test_rms_norm.py
import os
import tempfile

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from test.utils import assert_verbose_allclose
from test.utils import set_seed
from test.utils import supports_bfloat16

from liger_kernel.ops import LigerRMSNormFunction
from liger_kernel.transformers.functional import liger_rms_norm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.utils import infer_comm_backend
from liger_kernel.utils import infer_device

# Device string used for every tensor/module created in this file; it is later
# compared against "cuda" (below) and "npu" (in a skip marker).
device = infer_device()

# Fix the seed through the shared test helper and request deterministic
# algorithms from PyTorch before any test tensors are created.
set_seed(42)
torch.use_deterministic_algorithms(True)

#  Only setting torch.use_deterministic_algorithms(True) might throw the following error:
#  RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`,
#  but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an
#  environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information,
#  go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility

# Only the CUDA path needs the cuBLAS workspace setting described above.
if device == "cuda":
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm whether it is still needed.
SLEEP_SECONDS = 0.1


class BaseRMSNorm(nn.Module):
    """Plain RMS normalization computed entirely in the input's own dtype.

    The input is scaled by the reciprocal square root of the mean of its
    squares over the last dimension (plus ``eps``).  When
    ``elementwise_affine`` is true, the result is additionally multiplied by
    a learnable ``weight`` of shape ``(hidden_size,)`` initialised to ones;
    otherwise ``weight`` is registered as ``None``.
    """

    def __init__(self, hidden_size, eps=1e-6, elementwise_affine=True):
        super().__init__()
        self.variance_epsilon = eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            # Keep the attribute present (as None) so callers can still read it.
            self.register_parameter("weight", None)
        else:
            self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Return the RMS-normalized (and optionally weighted) input."""
        original_dtype = hidden_states.dtype
        mean_square = hidden_states.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (hidden_states * inv_rms).to(original_dtype)
        if not self.elementwise_affine:
            return normalized
        return self.weight * normalized


# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L112
class LlamaRMSNorm(nn.Module):
    """Llama-style RMS normalization (equivalent to T5LayerNorm).

    The statistics are computed after upcasting the input to float32; the
    normalized tensor is cast back to the input dtype *before* the optional
    ``weight`` multiplication.
    """

    def __init__(self, hidden_size, eps=1e-6, elementwise_affine=True):
        super().__init__()
        self.variance_epsilon = eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            # Keep the attribute present (as None) so callers can still read it.
            self.register_parameter("weight", None)
        else:
            self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Return the RMS-normalized (and optionally weighted) input."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        # Downcast first, then apply the weight, exactly as the reference does.
        normalized = (upcast * inv_rms).to(original_dtype)
        if not self.elementwise_affine:
            return normalized
        return self.weight * normalized


# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L122
class GemmaRMSNorm(nn.Module):
    """Gemma-style RMS normalization.

    Everything is computed in float32 and the scale applied is
    ``1.0 + weight`` (not ``weight``); the final result is cast back to the
    type of the input.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6, elementwise_affine=True):
        super().__init__()
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        if not elementwise_affine:
            # Keep the attribute present (as None) so callers can still read it.
            self.register_parameter("weight", None)
        else:
            self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS over its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Return the normalized input, scaled by ``1 + weight`` when affine."""
        normalized = self._norm(x.float())
        if not self.elementwise_affine:
            return normalized.type_as(x)
        scale = 1.0 + self.weight.float()
        return (normalized * scale).type_as(x)


@pytest.mark.flaky(reruns=3, reruns_delay=2)
@pytest.mark.parametrize(
    "bs, sl, hd",
    [(2, 128, 512), (5, 123, 123)],  # the second entry is a deliberately irregular shape
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-6),
        pytest.param(
            torch.bfloat16,
            2e-1,
            2e-2,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
@pytest.mark.parametrize(
    "reference, offset, casting_mode",
    [
        (LlamaRMSNorm, 0.0, "llama"),
        (GemmaRMSNorm, 1.0, "gemma"),
        pytest.param(
            BaseRMSNorm,
            0.0,
            "none",
            marks=pytest.mark.skipif(device == "npu", reason="Ascend NPU does not support this test"),
        ),
    ],
)
@pytest.mark.parametrize("in_place", [True, False])
@pytest.mark.parametrize("elementwise_affine", [True, False])
def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode, in_place, elementwise_affine):
    """Compare LigerRMSNorm against a PyTorch reference module.

    Both modules receive identical inputs and the same upstream gradient;
    the forward output, the input gradient and (when affine) the weight
    gradient must agree within ``atol``/``rtol``.
    """
    # Random draws happen in this order: input first, upstream gradient second.
    _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype)
    h1 = _tensor.clone().requires_grad_(True)  # fed to the reference
    h2 = _tensor.clone().requires_grad_(True)  # fed to the Liger module
    upstream_grad = torch.randn(bs, sl, hd, device=device, dtype=dtype)

    # Reference path.
    ref_module = reference(hidden_size=hd, elementwise_affine=elementwise_affine)
    ref_module = ref_module.to(device).to(dtype)
    ref_out = ref_module(h1)
    ref_out.backward(upstream_grad, retain_graph=True)

    # Liger path.
    liger_kwargs = dict(
        hidden_size=hd,
        offset=offset,
        casting_mode=casting_mode,
        in_place=in_place,
        elementwise_affine=elementwise_affine,
    )
    liger_module = LigerRMSNorm(**liger_kwargs).to(device).to(dtype)
    liger_out = liger_module(h2)
    liger_out.backward(upstream_grad, retain_graph=True)

    assert_verbose_allclose(ref_out, liger_out, atol=atol, rtol=rtol)
    if elementwise_affine:
        # Without the affine weight there is no parameter gradient to compare.
        assert_verbose_allclose(ref_module.weight.grad, liger_module.weight.grad, atol=atol, rtol=rtol)
    print(f"{h1.grad=}")
    print(f"{h2.grad=}")
    assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol, max_print=20)


@pytest.mark.parametrize(
    "bs, sl, hd",
    [
        (2, 2, 8),
        # weird shapes
        (9, 7, 41),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-6),
        # Skip bfloat16 on hardware without support, the same way
        # test_correctness guards its bfloat16 case.
        pytest.param(
            torch.bfloat16,
            2e-1,
            2e-2,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
@pytest.mark.parametrize(
    "reference, offset, casting_mode",
    [
        (LlamaRMSNorm, 0.0, "llama"),
        (GemmaRMSNorm, 1.0, "gemma"),
    ],
)
@pytest.mark.parametrize(
    "elementwise_affine",
    [
        True,
        False,
    ],
)
def test_correctness_functional(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode, elementwise_affine):
    """Check that ``liger_rms_norm`` matches ``LigerRMSNormFunction.apply``.

    The functional wrapper and the autograd function are called with the same
    input, weight (``None`` when ``elementwise_affine`` is false), epsilon,
    offset and casting mode; their outputs and input gradients must agree
    within ``atol``/``rtol``.  ``reference`` is part of the parametrization
    but is not used by this test.
    """
    _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype)

    h1 = _tensor.clone().requires_grad_(True)
    h2 = _tensor.clone().requires_grad_(True)

    if elementwise_affine:
        w = torch.randn(hd, device=device, dtype=dtype)
    else:
        w = None

    y1 = liger_rms_norm(X=h1, W=w, eps=1e-6, offset=offset, casting_mode=casting_mode)
    y2 = LigerRMSNormFunction.apply(h2, w, 1e-6, offset, casting_mode)

    assert torch.allclose(y1, y2, atol=atol, rtol=rtol)

    grad = torch.randn_like(y2)

    # Give each backward pass its own copy of the upstream gradient.
    # NOTE(review): the module exposes an `in_place` option (see
    # test_correctness); presumably the backward may reuse the upstream
    # gradient buffer, in which case sharing one tensor would let the first
    # backward corrupt the second — confirm against the kernel.
    y1.backward(grad.clone())
    y2.backward(grad.clone())

    assert torch.allclose(h1.grad, h2.grad, atol=atol, rtol=rtol)


def _test_dtensor_rms_norm(rank, world_size, bs, sl, hd, dtype, atol, rtol, offset, casting_mode, file_name):
    """Per-process body of the DTensor RMSNorm test.

    Joins a process group through a shared rendezvous file, then runs
    ``liger_rms_norm`` once on a tensor sharded along its last dimension and
    once on the original tensor, and asserts that outputs, weight gradients
    and input gradients agree within ``atol``/``rtol``.

    Args:
        rank: Index of this process within the group.
        world_size: Total number of processes in the group.
        bs, sl, hd: Sizes of the three dimensions of the random input.
        dtype: Dtype of the input, weight and gradient tensors.
        atol, rtol: Tolerances passed to ``torch.testing.assert_close``.
        offset, casting_mode: Forwarded unchanged to ``liger_rms_norm``.
        file_name: Path used for the ``file://`` rendezvous.
    """
    torch.distributed.init_process_group(
        backend=infer_comm_backend(),
        init_method=f"file://{file_name}",
        rank=rank,
        world_size=world_size,
    )
    # One accelerator per rank; the CPU backend has no device index.
    device = f"{infer_device()}:{rank}" if infer_device() != "cpu" else "cpu"
    device_mesh = torch.distributed.device_mesh.init_device_mesh(
        infer_device(), mesh_shape=(world_size,), mesh_dim_names=("tp",)
    )
    t = torch.randn(bs, sl, hd, device=device, dtype=dtype, requires_grad=True)
    dt = torch.distributed.tensor.distribute_tensor(
        t,
        device_mesh=device_mesh,
        placements=[torch.distributed.tensor.Shard(2)],
    )
    w = torch.randn(hd, device=device, dtype=dtype, requires_grad=True)
    # A detached clone does not require grad, so its `.grad` would stay None
    # and the weight-gradient comparison below would compare None with None.
    # Re-enable grad on both copies so that check is meaningful.
    w1 = w.detach().clone().requires_grad_(True)
    w2 = w.detach().clone().requires_grad_(True)

    y1 = liger_rms_norm(X=dt, W=w1, eps=1e-6, offset=offset, casting_mode=casting_mode)
    y2 = liger_rms_norm(X=t, W=w2, eps=1e-6, offset=offset, casting_mode=casting_mode)
    torch.testing.assert_close(y1, y2, atol=atol, rtol=rtol)

    grad = torch.randn_like(y2)
    # Shard the upstream gradient with the same placement as the input.
    dgrad = torch.distributed.tensor.distribute_tensor(
        grad,
        device_mesh=device_mesh,
        placements=[torch.distributed.tensor.Shard(2)],
    )

    y1.backward(dgrad)
    y2.backward(grad)
    torch.testing.assert_close(w1.grad, w2.grad, atol=atol, rtol=rtol)
    torch.testing.assert_close(dt.grad, t.grad, atol=atol, rtol=rtol)


@pytest.mark.xfail(
    torch.cuda.device_count() < 8,
    reason="Pending multi-GPU host support. This test is expected to pass when run with multi-GPU host.",
)
@pytest.mark.parametrize(
    "world_size, bs, sl, hd",
    [(4, 2, 2, 8), (8, 9, 7, 64)],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [(torch.float32, 1e-4, 1e-6), (torch.bfloat16, 2e-1, 2e-2)],
)
@pytest.mark.parametrize(
    "offset, casting_mode",
    [(0.0, "llama"), (1.0, "gemma")],
)
def test_dtensor_rms_norm(world_size, bs, sl, hd, dtype, atol, rtol, offset, casting_mode):
    """Launch ``world_size`` worker processes running ``_test_dtensor_rms_norm``.

    A temporary file serves as the rendezvous point; its path is passed to
    every worker as the last argument.  The call blocks until all workers
    have finished.
    """
    with tempfile.NamedTemporaryFile() as rendezvous_file:
        # The spawner supplies the process index as the first positional
        # argument; everything else is forwarded from here.
        worker_args = (world_size, bs, sl, hd, dtype, atol, rtol, offset, casting_mode, rendezvous_file.name)
        mp.spawn(_test_dtensor_rms_norm, args=worker_args, nprocs=world_size, join=True)