# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from collections import OrderedDict
import math
import os
from typing import Dict, List, Tuple, Optional
import pytest
import copy
import random

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.utils.cpp_extension import IS_HIP_EXTENSION

from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    fp8_autocast,
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
    attention_mask_func,
    is_bf16_compatible,
)
from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    GroupedLinear,
    BatchedLinear,
    MultiheadAttention,
    RMSNorm,
    TransformerLayer,
    LayerNorm,
    Fp8Padding,
    Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm, batchgemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

sm_80plus = get_device_compute_capability() >= (8, 0)

seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

if torch_version() >= (2, 7, 0):
    torch._dynamo.config.recompile_limit = 16
else:
    torch._dynamo.config.cache_size_limit = 16

class ModelConfig:
    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
        self.hidden_size = hidden_size
        self.eps = eps
        self.num_attention_heads = num_attention_heads
        self.embed = embed
        self.num_layers = num_layers
        self.seq_len = seq_len


model_configs = {
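    # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len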
    "small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}

model_configs_inference = {
    # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]

param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

batch_sizes = [1, 2]

all_boolean = [True, False]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]

all_normalizations = ["LayerNorm", "RMSNorm"]

mask_types = ["causal", "no_mask"]

fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
    recipe.Float8CurrentScaling(),
]


def get_causal_attn_mask(sq: int) -> torch.Tensor:
    return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()


def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.

    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsupported dtype ({dtype})")


def assert_allclose(
    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
) -> None:
    """Ensure two lists of tensors are element-wise close."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        tols = dict(atol=atol)
        if rtol is not None:
            tols["rtol"] = rtol
        result = torch.allclose(t1, t2, **tols)
        if not result:
            diff = torch.abs(t1 - t2)
            # Fall back to torch.allclose's default rtol when none was provided.
            rtol_eff = rtol if rtol is not None else 1e-5
            tol = atol + rtol_eff * torch.abs(t2)
            exceed_mask = diff > tol
            msg = f"Outputs not close enough in tensor at idx={i}."
            if exceed_mask.any():
                indices = torch.nonzero(exceed_mask, as_tuple=True)
                max_diff = diff[exceed_mask].max()
                max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
                max_location = [idx[max_idx].item() for idx in indices]
                msg = (
                    f"Outputs not close enough in tensor at idx={i}. "
                    f"Maximum difference at location {max_location} "
                    f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
                    f"(diff {max_diff.item()})."
                )
            raise AssertionError(msg)


def reset_rng_states() -> None:
    """Revert to the initial RNG state."""
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    FP8GlobalStateManager.reset()

def _test_batched_linear_accuracy(
    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
):
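    """Run forward/backward on `block` and collect outputs and gradients.

    `block` is either a BatchedLinear or a ModuleList of per-GEMM Linear
    modules; the output, the input gradient, and the weight gradients are
    returned so the two paths can be compared.
    """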
    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    assert config.seq_len % num_gemms == 0
    m_splits = torch.tensor([config.seq_len // num_gemms for i in range(num_gemms)])
    assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, BatchedLinear):
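            # Scale the per-GEMM split sizes by the batch size: the splits are
            # taken over the flattened (seq_len * bs) leading dimension, while
            # the reference path below splits the unflattened input along seq_len.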
            m_splits = m_splits * bs
            out = block(inp_hidden_states, m_splits.tolist())
        else:
            out = torch.cat(
                [
                    block[i](inp)
                    for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                ]
            )
    loss = out.sum()
    loss.backward()
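
    # With delay_wgrad_compute, loss.backward() defers the weight-gradient
    # GEMMs; they are computed explicitly via backward_dw() below.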
    if delay_wgrad_compute:
        if isinstance(block, BatchedLinear):
            block.backward_dw()
        else:
            for i in range(num_gemms):
                block[i].backward_dw()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
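
    # Collect weight gradients for comparison. Each BatchedLinear parameter holds
    # batch_num GEMM weights stacked along dim 0, so its (main_)grad is sliced back
    # into per-GEMM chunks to line up with the reference Linear modules.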
    for p in block.parameters():
        if p.requires_grad:
            if isinstance(block, BatchedLinear):
                if getattr(p, "main_grad", None) is not None:
                    chunk = p.main_grad.shape[0] // batch_num
                    for j in range(batch_num):
                        outputs.append(p.main_grad[chunk * j : chunk * (j + 1)])
                        assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
                else:
                    chunk = p.grad.shape[0] // batch_num
                    for j in range(batch_num):
                        outputs.append(p.grad[chunk * j : chunk * (j + 1)])
            else:
                if getattr(p, "main_grad", None) is not None:
                    outputs.append(p.main_grad)
                    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
                else:
                    outputs.append(p.grad)
    return outputs

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [4, 8])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_batched_linear_accuracy(
    dtype,
    num_gemms,
    bs,
    model,
    fp8,
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    delay_wgrad_compute,
    parallel_mode=None,
):
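    """Compare BatchedLinear against an equivalent list of per-GEMM Linear modules."""
    # batch_num is the number of GEMM weights stacked into each BatchedLinear
    # parameter (see the parameter-sharing loop below).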
    batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
        pytest.skip("MXFP8 unsupported for batched linear.")
    if fp8 and recipe.float8_current_scaling():
        pytest.skip("Float8 Current Scaling unsupported for batched linear.")

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        batched_linear = BatchedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            delay_wgrad_compute=delay_wgrad_compute,
        ).eval()
        sequential_linear = torch.nn.ModuleList(
            [
                Linear(
                    config.hidden_size,
                    4 * config.hidden_size,
                    bias=False,
                    params_dtype=dtype,
                    parallel_mode=parallel_mode,
                    device="cuda",
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                ).eval()
                for _ in range(num_gemms)
            ]
        )

    # Share params
    with torch.no_grad():
        for i in range(num_gemms // batch_num):
            weight = getattr(batched_linear, f"weight{i}").clone()
            # bias = getattr(batched_linear, f"bias{i}").clone()
            if fuse_wgrad_accumulation:
                weight_i = getattr(batched_linear, f"weight{i}")
                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
            chunk = weight.shape[0] // batch_num
            for j in range(batch_num):
                sequential_linear[i * batch_num + j].weight = Parameter(
                    weight[chunk * j : chunk * (j + 1)].clone()
                )
                # sequential_linear[i * batch_num + j].bias = Parameter(
                #     bias[chunk * j : chunk * (j + 1)].clone()
                # )
                if fuse_wgrad_accumulation:
                    sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[
                        chunk * j : chunk * (j + 1)
                    ].clone()

    outputs_ref = _test_batched_linear_accuracy(
        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
    )
    outputs = _test_batched_linear_accuracy(
        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
    )

    # Outputs should match within tolerance
    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)

if __name__ == "__main__":
    test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True, True)