# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
This file contains tests for exporting TransformerEngine models to ONNX.

The purpose of these tests is to validate that TE models are converted to their correct ONNX
representation. Toward this end, each test captures the output of a TE module forward pass,
converts the TE module to ONNX, and uses ONNX Runtime (ORT) to execute the ONNX graph and
validate the output against TE's output.

Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented
using custom ORT operations.

To run many repetitive tests use pytest-loop:
    $ python3 -m pip install pytest-loop
    $ pytest --loop 1000 tests/pytorch/test_onnx_export.py::test_export_layernorm

For reproducibility use: torch.manual_seed(0)
"""

23

24
import os
25
import tempfile
26
27
28
29
30
31
import pytest
import warnings
import numpy as np
import onnxruntime as ort
import torch
from torch import nn as nn
32
from typing import Optional, Union, Tuple, List
33
34
35
36
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, fp8_gelu, cast_to_fp8, cast_from_fp8
37
from transformer_engine.pytorch.module.base import get_workspace
38
39
40
import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine.pytorch.softmax as softmax_defs
from transformer_engine.pytorch.utils import get_default_init_method
41
from transformer_engine.pytorch.export import is_in_onnx_export_mode
42
from transformer_engine.pytorch.fp8 import is_fp8_available
43

44
# Global test configuration knobs.

# Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance).
SAVE_TEST_IO = bool(int(os.getenv("NVTE_ONNX_EXPORT_SAVE_TEST_IO", "0")))

if SAVE_TEST_IO:
    # Polygraphy is only imported when test I/O is actually being saved.
    from polygraphy.json import save_json
    from polygraphy.comparator import RunResults

# The directory where generated ONNX test models are stored.
# Taken from the NVTE_TEST_ARTIFACTS_DIR environment variable when set (and non-empty),
# otherwise a "gen_onnx_models" directory under the system temp directory.
NVTE_TEST_ARTIFACTS_DIR = os.environ.get('NVTE_TEST_ARTIFACTS_DIR')
NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(tempfile.gettempdir(), "./gen_onnx_models")

# The directory where this file is stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))

# ScaledUpperTriangMaskedSoftmax is exported via ONNX::Trilu which was introduced in opset 14.
TRILU_OPSET = 14
# Opset used in the ONNX files generated by the tests.
OPSET = 17
assert OPSET >= TRILU_OPSET

# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")

# FP8 availability is queried once at import time; `skip_FP8` is a reusable marker
# for tests that can only run when FP8 is available.
fp8_available, reason_for_no_fp8 = is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
72

Neta Zmora's avatar
Neta Zmora committed
73

74
75
76
77
78
79
80
81
82
83
84
85
@pytest.fixture()
def seed_default_rng():
    """Reseed the default PyTorch PRNG before the test runs.

    NOTE(review): per the PyTorch documentation `torch.random.seed()` picks a
    non-deterministic seed, so tests using this fixture see fresh random data on
    every run. For a reproducible run use `torch.manual_seed(0)` instead, as the
    module docstring suggests.
    """
    torch.random.seed()


@pytest.fixture()
def set_max_seq_len(max_seq_len=128):
    """Set the maximum sequence length that can be used for attention masking.

    Writes `max_seq_len` (formatted as a string) to the
    NVTE_ONNX_KVCACHE_MAX_SEQ_LEN environment variable.

    NOTE(review): the variable is not restored or removed afterwards, so it stays
    set for the rest of the process -- confirm no other test depends on it being
    unset. Nothing visible here overrides the default of 128.
    """
    os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}"


86
87
88
89
90
91
92
93
94
95
def create_fp8_recipe():
    """Build the delayed-scaling FP8 recipe (E4M3 format, margin 0, interval 1) used by the tests."""
    recipe_kwargs = dict(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
    return recipe.DelayedScaling(**recipe_kwargs)


def do_export(
    model: torch.nn.Module,
    inp: torch.Tensor,
    fname: str,
    use_fp8: bool=True,
    opset: int=OPSET,
    input_names: Optional[List[str]]=None,
    output_names: Optional[List[str]]=None,
    dynamic_axes: Optional[dict]=None
):
    """Export `model` to an ONNX file inside NVTE_TEST_ARTIFACTS_DIR.

    Args:
        model: Module to export; it is moved to CUDA and put in eval mode.
        inp: A single input, or a list/tuple of inputs. Entries may be None, in
            which case the matching entry of `input_names` is dropped.
        fname: File name of the ONNX model, relative to NVTE_TEST_ARTIFACTS_DIR.
        use_fp8: Whether FP8 autocast is enabled during export.
        opset: ONNX opset version to export with.
        input_names: One name per entry of `inp`; defaults to ["input"].
        output_names: Names of the graph outputs; defaults to ["output"].
        dynamic_axes: Forwarded unchanged to `torch.onnx.export`.
    """
    fp8_recipe = create_fp8_recipe()
    input_names = input_names or ["input"]
    output_names = output_names or ["output"]

    with torch.inference_mode(), te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
        # Tracing emits many TracerWarnings; silence them for the duration of the export.
        warnings.filterwarnings(
            action='ignore',
            category=torch.jit.TracerWarning,
            module=r'.*'
        )

        model.cuda().eval()
        os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True)
        fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)

        inps = inp if isinstance(inp, (list, tuple)) else (inp,)
        assert len(inps) == len(input_names)
        # Keep only the names of the inputs that are actually provided (not None).
        input_names = [name for name, value in zip(input_names, inps) if value is not None]

        with te.onnx_export(True):
            torch.onnx.export(
                model,
                inps,
                fname,
                verbose=True,
                dynamic_axes=dynamic_axes,
                opset_version=opset,
                input_names=input_names,
                output_names=output_names,
                do_constant_folding=True,
                operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
133
134
135


def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array; return any other object unchanged.

    BF16 tensors are first converted to FP32 (NumPy has no native bfloat16 dtype).
    The tensor is detached and copied to the CPU before conversion.
    """
    if isinstance(tensor, torch.Tensor):
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.type(torch.float32)
        tensor = tensor.detach().cpu().numpy()
    return tensor
141
142


143
144
145
146
147
def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
    """Initialize the FP8 quantization scales in module.

    Calls `module.fp8_init(num_gemms)` and then overwrites the forward scaling
    tensors so that every `scale` entry is 1/scale and every `scale_inv` entry
    is `scale`.
    """
    NB_SCALES_PER_GEMM = 3  # One scale per: input, weights, and output GEMM tensors.
    nb_total_scales = num_gemms * NB_SCALES_PER_GEMM
    module.fp8_init(num_gemms)
    scaling_fwd = module.fp8_meta["scaling_fwd"]
    scaling_fwd.scale = torch.ones(
        nb_total_scales, dtype=torch.float32, device="cuda") / scale
    scaling_fwd.scale_inv = torch.ones(
        nb_total_scales, dtype=torch.float32, device="cuda") * scale
152
153
154


def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.Tensor], torch.Tensor], is_fp8: bool):
    """Transformer Engine forward propagation.

    Runs `model` in inference mode under FP8 autocast (enabled iff `is_fp8`).

    Args:
        model: The module to run.
        inps: A single input, or a tuple of inputs that is unpacked into
            positional arguments. Only a tuple is unpacked; any other object is
            passed as one argument.
        is_fp8: Whether FP8 autocast is enabled.

    Returns:
        The model outputs, always as a tuple.
    """
    fp8_recipe = create_fp8_recipe()
    with torch.inference_mode(), te.fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
        te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
        if not isinstance(te_outputs, tuple):
            te_outputs = (te_outputs,)
        return te_outputs
162
163


164
165
166
167
168
169
def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname):
    """Compare ORT and TE outputs.

    Each pair of outputs is compared element-wise with `np.isclose`. Mismatches
    are printed (at most `max_errors_printed` per output).

    Raises:
        ValueError: If any single output has more than `allow_cnt_errors`
            mismatching elements.
    """
    assert len(onnx_outputs) == len(te_outputs)
    # Compare ORT and PyTorch outputs.
    for onnx_output, te_output in zip(onnx_outputs, te_outputs):
        # np.isclose: abs(a - b) <= (atol + rtol * abs(b))
        te_output = to_numpy(te_output)
        onnx_output = to_numpy(onnx_output)
        ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)
        mismatches = ac.nonzero()
        # One index tuple per mismatching element.
        mismatched_ids = list(zip(*mismatches))
        if mismatched_ids:
            # Log some information in case of error.
            print("*" * 100)
            nb_errors = len(mismatched_ids)
            nb_vals = min(nb_errors, max_errors_printed)
            print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})")
            print(f"Showing first {nb_vals} errors (ONNX -- TE):")
            abs_err = np.abs(onnx_output - te_output)
            errors = abs_err[mismatches]
            for loc in mismatched_ids[:nb_vals]:
                ref = te_output[loc]
                print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}")
            print(f"Max error: {np.max(errors)}")
            if nb_errors > allow_cnt_errors:
                raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")

191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
def serialize_inputs_outputs(
    fname: str,
    inputs: Union[Tuple[torch.Tensor], torch.Tensor],
    te_outputs: List[torch.Tensor],
    input_names: Optional[List[str]] = None,
    output_names: Optional[List[str]] = None,
):
    """Save the test inputs and TE outputs next to the ONNX model (only if SAVE_TEST_IO).

    The inputs are written with Polygraphy's `save_json` to "<stem>_inputs.json"
    and the outputs as a Polygraphy `RunResults` to "<stem>_output.json", where
    <stem> is the model path with its last len(".onnx") characters removed.
    Inputs/outputs that are None are left out.
    """
    if not SAVE_TEST_IO:
        return

    model_path = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
    # Both JSON files share the model path minus its trailing ".onnx" characters.
    stem = model_path[:-len(".onnx")]

    input_names = input_names or ["input"]
    output_names = output_names or ["output"]
    if not isinstance(inputs, (list, tuple)):
        inputs = (inputs,)

    # Inputs: a single-element list holding one {name: CPU tensor} feed dict.
    feed = {}
    for name, tensor in zip(input_names, inputs):
        if tensor is not None:
            feed[name] = tensor.cpu()
    save_json([feed], stem + "_inputs.json", description="custom input data")

    # Outputs: wrapped in a RunResults instance under a fixed runner name.
    results = {}
    for name, tensor in zip(output_names, te_outputs):
        if tensor is not None:
            results[name] = tensor.cpu()
    run_results = RunResults()
    run_results.add([results], runner_name="custom_runner")
    run_results.save(stem + "_output.json")

218

219
220
221
222
223
224
225
226
def validate_result(
    fname: str,
    inps: Union[Tuple[torch.Tensor], torch.Tensor],
    model: torch.nn.Module,
    atol: float=1.e-8, # np.isclose default atol
    rtol: float=1.e-5, # np.isclose default rtol
    max_errors_printed: int=10,
    is_fp8: bool=False,
    allow_cnt_errors: int=0,
    input_names: Optional[List[str]]=None,
    output_names: Optional[List[str]]=None,
    te_outputs: Optional[List[torch.Tensor]]=None,
):
    """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
    representation using ONNX Runtime (ORT) and ensure they are close.

    The purpose of the output comparison is to validate that TE models are converted to
    their correct ONNX representation by testing that TE and ORT outputs match within some
    small threshold (allowing for finite precision errors).

    Argument `allow_cnt_errors` reduces test failure noise due to spurious errors by ignoring
    a very small number (0-3) of outliers. This is fine to do because these outliers are due to
    small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX
    representation (the tests assume both ORT and TE kernels are correct).

    Argument `te_outputs` can be used to provide pre-computed TE outputs.

    Arguments `input_names` and `output_names` are accepted for call-site symmetry with
    the other helpers but are not used here: the ORT feed dict is keyed by the input
    names reported by the ORT session itself.

    Raises:
        FileNotFoundError: If `is_fp8` is set and the custom FP8 ORT ops library is missing.
        ValueError: Propagated from `compare_outputs` when too many values diverge.
    """

    def create_ort_session(fname: str, is_fp8: bool):
        """Create an ONNX Runtime session for validation."""

        def load_custom_ops(session_opts: ort.SessionOptions):
            """For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension."""
            if not os.path.exists(ORT_CUSTOM_OPS_LIB):
                raise FileNotFoundError(f"Unable to find {ORT_CUSTOM_OPS_LIB}")
            session_opts.register_custom_ops_library(ORT_CUSTOM_OPS_LIB)
            print("registered custom FP8 Q/DQ ops!")

        kwargs = {}
        if is_fp8:
            sess_options = ort.SessionOptions()
            load_custom_ops(sess_options)
            kwargs["sess_options"] = sess_options

        return ort.InferenceSession(fname, **kwargs)

    def create_ort_input_dict(session, inputs):
        """Pair the session's input names with the non-None inputs, converted to NumPy."""
        inputs = inputs if isinstance(inputs, (list, tuple)) else (inputs,)
        session_input_names = [x.name for x in session.get_inputs()]
        np_inputs = [to_numpy(x) for x in inputs if x is not None]
        return dict(zip(session_input_names, np_inputs))

    # Run ORT session and TE model.
    fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
    if not te_outputs:
        te_outputs = te_infer(model, inps, is_fp8)
    ort_s = create_ort_session(fname, is_fp8)
    input_feed = create_ort_input_dict(ort_s, inps)
    onnx_outputs = ort_s.run(None, input_feed=input_feed)
    compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname)
283
284
285
286
287
288
289
290
291
292


def create_meta(scale_factor: float, size: int=1):
    """Create an FP8TensorMeta holding `size` scales, all derived from `scale_factor`.

    `scale` is filled with `scale_factor`, `scale_inv` with its reciprocal, and
    `amax_history` is a zeroed (1, size) tensor. All tensors are FP32 on CUDA.
    """
    tensor_opts = dict(dtype=torch.float32, device="cuda")
    fp8_meta = tex.FP8TensorMeta()
    fp8_meta.amax_history = torch.zeros(1, size, **tensor_opts)
    fp8_meta.scale_inv = torch.ones(size, **tensor_opts) / scale_factor
    fp8_meta.scale = torch.ones(size, **tensor_opts) * scale_factor
    return fp8_meta


293
294
295
296
def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
    """Return the file-name suffix for a precision.

    Args:
        dtype: One of torch.float32, torch.float16 or torch.bfloat16.
        fake_bf16_io: When true, `dtype` must be torch.bfloat16 and the
            dedicated "_fake_bf16" suffix is returned.

    Raises:
        KeyError: If `dtype` is not one of the supported dtypes.
    """
    if fake_bf16_io:
        assert dtype == torch.bfloat16
        return "_fake_bf16"
    return {
        torch.float32: "_fp32",
        torch.float16: "_fp16",
        torch.bfloat16: "_bf16",
    }[dtype]


def as_te_type(dtype: torch.dtype):
    """Translate a torch dtype (FP32, FP16 or BF16) to the matching `tex.DType` value.

    Any other dtype raises KeyError.
    """
    torch_to_te = {
        torch.float32: tex.DType.kFloat32,
        torch.float16: tex.DType.kFloat16,
        torch.bfloat16: tex.DType.kBFloat16,
    }
    return torch_to_te[dtype]


def get_attn_mask_str(use_mask, attn_mask_type):
    """Return the file-name fragment describing the attention-mask configuration."""
    # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
    if attn_mask_type is None:
        if use_mask:
            return "_mask"
        return "_no-mask"
    if attn_mask_type == "causal":
        return "_causal-mask"
    if attn_mask_type == "padding" and use_mask:
        return "_padding-mask"
    # Any remaining mask type (including "padding" without a mask) gets this name.
    return "_padding-no-mask"


Neta Zmora's avatar
Neta Zmora committed
322
323
324
325
326
"""
Test cases begin here.
"""


327
@skip_FP8
@pytest.mark.parametrize("scale_factor", [1, 224])
@pytest.mark.parametrize(
    "precision,             atol", [
    [torch.float32,         1e-7],
    [torch.float16,         1e-7],
    [torch.bfloat16,        5e-3],
    ["fake-torch.bfloat16", 5e-3],
])
def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype):
    """Export an FP8 cast round trip (cast_to_fp8 followed by cast_from_fp8) and validate it.

    The "fake-torch.bfloat16" precision runs the cast in BF16 while the model's
    input and output tensors are FP32.
    """
    fake_bf16_io = precision == "fake-torch.bfloat16"
    # reset precision to torch.bfloat16 after capturing fake BF16 mode
    precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision

    class TestFP8_QDQ(nn.Module):
        def __init__(self, fake_bf16_io):
            super().__init__()
            self.fp8_tensor = 0
            self.meta = create_meta(scale_factor)
            self.highprec_type = as_te_type(precision)
            self.fp8_type = tex.DType.kFloat8E4M3
            self.fake_bf16_io = fake_bf16_io

        def forward(self, inp):
            ret = cast_to_fp8(
                inp,
                self.meta,
                self.fp8_tensor,
                self.fp8_type)

            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                self.highprec_type)
            if self.fake_bf16_io:
                ret = ret.type(torch.float32)
            return ret

    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256
    inp = torch.randn(hidden_size, in_features, device="cuda",
        dtype=torch.float if fake_bf16_io else precision)
    high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
    fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx"
    model = TestFP8_QDQ(fake_bf16_io)

    do_export(model, inp, fname)
    te_outputs = te_infer(model, inp, is_fp8=True)
    serialize_inputs_outputs(fname, inp, te_outputs)
    # Validation is skipped only for real BF16 I/O
    # (presumably because ORT is fed through NumPy, which lacks BF16 -- TODO confirm).
    if fake_bf16_io or precision != torch.bfloat16:
        validate_result(fname, inp, model, atol=atol, is_fp8=True, te_outputs=te_outputs)
381
382
383
384

@skip_FP8
@pytest.mark.parametrize("scale_factor", [448])
@pytest.mark.parametrize(
    "precision,             atol", [
    [torch.float32,         1e-5],
    [torch.float16,         1e-5],
    [torch.bfloat16,        5e-3],
    ["fake-torch.bfloat16", 5e-3]
])
def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
    """Export fp8_gelu followed by cast_from_fp8 and validate the ONNX model against TE.

    The "fake-torch.bfloat16" precision runs the op in BF16 while the model's
    input and output tensors are FP32.
    """
    fake_bf16_io = precision == "fake-torch.bfloat16"
    # reset precision to torch.bfloat16 after capturing fake BF16 mode
    precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision

    class TestFP8_Gelu(nn.Module):
        def __init__(self, fake_bf16_io):
            super().__init__()
            self.fp8_tensor = 0
            self.meta = create_meta(scale_factor)
            self.highprec_type = as_te_type(precision)
            self.fp8_type = tex.DType.kFloat8E4M3
            self.fake_bf16_io = fake_bf16_io

        def forward(self, inp):
            ret = fp8_gelu(
                inp,
                self.meta,
                self.fp8_tensor,
                self.fp8_type)
            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                self.highprec_type)
            if self.fake_bf16_io:
                ret = ret.type(torch.float32)
            return ret

    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256
    inp = torch.randn(hidden_size, in_features, device="cuda",
        dtype=torch.float if fake_bf16_io else precision)
    high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
    fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx"
    model = TestFP8_Gelu(fake_bf16_io)
    do_export(model, inp, fname)
    te_outputs = te_infer(model, inp, is_fp8=True)
    serialize_inputs_outputs(fname, inp, te_outputs)
    # Validation is skipped only for real BF16 I/O
    # (presumably because ORT is fed through NumPy, which lacks BF16 -- TODO confirm).
    if fake_bf16_io or precision != torch.bfloat16:
        validate_result(fname, inp, model, rtol=0, atol=atol, is_fp8=True, allow_cnt_errors=2, te_outputs=te_outputs)
434
435
436
437
438
439


@pytest.mark.parametrize("scale_factors",
    [(224, 224,),
])
@pytest.mark.parametrize(
    "precision,             use_fp8, use_bias, use_gelu", [
    (torch.float32,         False,   False,    False),
    (torch.float16,         False,   False,    False),
    (torch.float32,         False,   True,     False),
    (torch.float16,         False,   True,     False),
    (torch.float32,         False,   True,     True),
    (torch.float16,         False,   True,     True),

    # For FP8 GEMM GeLU is not used.
    (torch.float32,         True,    False,    False),
    (torch.float16,         True,    False,    False),
    # When enabling bias we must use float16 or bfloat16 (because of kernel limitations)
    (torch.float16,         True,    True,     False),
    (torch.bfloat16,        True,    True,     False),
])
def test_export_gemm(
    seed_default_rng,
    precision, # Precision of inputs, weights, output and bias
    use_fp8,
    use_bias,
    use_gelu,
    scale_factors
):
    """Export a GEMM (FP8 or high precision, optional bias/GeLU) and validate it against TE.

    `scale_factors` is an (activation scale, weight scale) pair and is only used
    by the FP8 variant.
    """
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    class TestFP8_GEMM(nn.Module):
        def __init__(self, precision, use_bias, gelu, scale_factors):
            super().__init__()
            self.use_bias = use_bias
            self.gelu = gelu
            self.precision = precision

            self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
            self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
            nb_inp_scales, nb_weight_scales = 1, out_features
            act_scale_factor, weight_scale_factor = scale_factors
            self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
            self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)

            bias_size = nb_weight_scales
            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
            self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")

            self.inp_type = tex.DType.kFloat8E4M3
            self.weights_type = tex.DType.kFloat8E4M3
            self.outp_type = precision

        def forward(self, inp, weight):
            inp_fp8 = cast_to_fp8(
                inp,
                self.meta_inp,
                self.fp8_tensor_inp,
                self.inp_type)

            weight_fp8 = cast_to_fp8(
                weight,
                self.meta_weight,
                self.fp8_tensor_weight,
                self.weights_type)

            ret = fp8_gemm(
                weight_fp8,
                self.meta_weight.scale_inv,
                self.fp8_tensor_weight,
                self.inp_type,
                inp_fp8,
                self.meta_inp.scale_inv,
                self.fp8_tensor_inp,
                self.weights_type,
                self.outp_type,
                get_workspace(),
                bias=self.bias,
                use_bias=self.use_bias,
                use_split_accumulator=False)
            return ret

    class Test_GEMM(nn.Module):
        def __init__(self, precision, use_bias=False, gelu=False):
            super().__init__()
            self.use_bias = use_bias
            self.gelu = gelu
            self.precision = precision
            bias_size = out_features
            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
            self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")

        def forward(self, inp, weight):
            outp_type = self.precision

            # note: due to logic in lines 104:116 and L129 in cpp_extensions.py
            # it appears either bias OR gelu can be activated, not both
            ret, _, _ = gemm(
                weight,
                inp,
                outp_type,
                get_workspace(),

                # test bias
                bias=self.bias,
                use_bias=self.use_bias,

                # test gelu
                gelu=self.gelu,
                gelu_input=self.gelu_input,
                grad=False, # only True for backward pass
                accumulate=False,
            )
            return ret

    # If gelu is applied then bias must be added, as defined by TE kernel.
    if use_gelu: assert use_bias
    # Set dimensions (these are arbitrary).
    # NOTE: the nested modules above read out_features/hidden_size from this scope
    # when they are instantiated, so these must be defined before the models are built.
    out_features = 128
    hidden_size = 256
    in_features = 64
    inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
    weight = torch.randn(out_features, in_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    gelu_str = "_gelu" if use_gelu else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx"
    input_names = ['input', 'weight']
    inputs = (inp, weight)

    if use_fp8:
        model = TestFP8_GEMM(precision, use_bias, use_gelu, scale_factors)
    else:
        model = Test_GEMM(precision, use_bias, use_gelu)

    # Export / infer / validate is identical for both variants; only is_fp8 differs.
    do_export(model, inputs, fname, use_fp8, input_names=input_names)
    te_outputs = te_infer(model, inputs, is_fp8=use_fp8)
    serialize_inputs_outputs(fname, inputs, te_outputs, input_names=input_names)
    if precision != torch.bfloat16:
        validate_result(fname, inputs, model, rtol=1e-2, atol=2e-2,
            is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs)
581
582
583


@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize(
    "use_fp8, precision,             atol", [
    [False,   torch.float32,         1e-7],
    [False,   torch.float16,         1e-7],
    [False,   torch.bfloat16,        1e-7],
    [False,   "fake-torch.bfloat16", 1e-7],
    [True,    torch.float32,         1e-7],
    [True,    torch.float16,         1e-7],
    [True,    torch.bfloat16,        1e-2],
    [True,    "fake-torch.bfloat16", 1e-2]
])
def test_export_layernorm(
    seed_default_rng,
    use_fp8: bool,
    scale_factor: float,
    precision: torch.dtype,
    zero_centered_gamma: bool,
    atol: float
):
    """Export the inference LayerNorm op (plain or FP8) and validate the ONNX model against TE.

    With the "fake-torch.bfloat16" precision the input, weight and bias tensors
    are FP32; the FP8 variant then casts its BF16 result back to FP32.
    """
    fake_bf16_io = precision == "fake-torch.bfloat16"
    # reset precision to torch.bfloat16 after capturing fake BF16 mode
    precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision

    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Set dimensions (these are arbitrary).
    inp_shape = [64, 32]

    class Test_Layernorm(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # `inp` is read from the enclosing test scope; it is defined before instantiation.
            normalized_shape = torch.Size(inp.shape[1:])
            self.weight = torch.randn(*normalized_shape, device="cuda",
                dtype=torch.float if fake_bf16_io else precision)
            self.bias = torch.zeros(*normalized_shape, device="cuda",
                dtype=torch.float if fake_bf16_io else precision)
            self.eps = 1e-6 # An arbitrary small value

        def forward(self, inp):
            ret = texcpp.layernorm_fwd_inf(
                inp,
                self.weight,
                self.bias,
                self.eps,
                zero_centered_gamma)
            return ret

    class TestFP8_Layernorm(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # `inp` is read from the enclosing test scope; it is defined before instantiation.
            normalized_shape = torch.Size(inp.shape[1:])
            self.weight = torch.randn(*normalized_shape, device="cuda",
                dtype=torch.float32 if fake_bf16_io else precision)
            self.bias = torch.zeros(*normalized_shape, device="cuda",
                dtype=torch.float32 if fake_bf16_io else precision)
            self.eps = 1e-6 # An arbitrary small value

            self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
            self.meta = create_meta(scale_factor)
            self.fp8_type = tex.DType.kFloat8E4M3

        def forward(self, inp):
            ret = texcpp.layernorm_fwd_fp8_inf(
                inp,
                self.weight,
                self.bias,
                self.eps,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                zero_centered_gamma)

            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                as_te_type(precision))
            if fake_bf16_io:
                ret = ret.type(torch.float32)
            return ret

    inp = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision)
    model = TestFP8_Layernorm() if use_fp8 else Test_Layernorm()
    high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
    fp8_str = f"_fp8-{scale_factor}" if use_fp8 else ""
    fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx"
    do_export(model, inp, fname, use_fp8=use_fp8)
    te_outputs = te_infer(model, inp, is_fp8=use_fp8)
    serialize_inputs_outputs(fname, inp, te_outputs)
    # Validation is skipped only for real BF16 I/O
    # (presumably because ORT is fed through NumPy, which lacks BF16 -- TODO confirm).
    if fake_bf16_io or precision != torch.bfloat16:
        validate_result(
            fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs)
680
681
682


@skip_FP8
@pytest.mark.parametrize("softmax_fn", [
    softmax_defs.ScaledUpperTriangMaskedSoftmax,
    softmax_defs.ScaledMaskedSoftmax,
    softmax_defs.ScaledSoftmax,
    te.softmax.FusedScaleMaskSoftmax,
])
# Softmax kernel only supports FP16 or BF16!
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"])
def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision):
    """Export each softmax variant to ONNX and validate it against the TE output.

    `precision` may be the string "fake-torch.bfloat16"; in that case the test runs
    with BF16 precision but the wrapper module casts its output to FP16
    ("fake BF16 I/O"). Validation is skipped for genuine BF16 I/O.
    """
    fake_bf16_io = precision == "fake-torch.bfloat16"
    # reset precision to torch.bfloat16 after capturing fake BF16 mode
    precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision

    class Test_Softmax(nn.Module):
        """Thin wrapper that dispatches to the softmax implementation under test."""

        def __init__(self, softmax_fn, fake_bf16_io, mask_inp=False):
            super().__init__()
            self.softmax_fn = softmax_fn
            self.scale = 8 # arbitrary value
            self.mask_inp = mask_inp
            self.fused_scaled_softmax = None
            self.fake_bf16_io = fake_bf16_io
            # FusedScaleMaskSoftmax is a module (not an autograd function), so it
            # has to be instantiated rather than invoked through `.apply`.
            if self.softmax_fn == te.softmax.FusedScaleMaskSoftmax:
                self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
                    attn_mask_type="causal",
                    mask_func=te.utils.attention_mask_func,
                    softmax_in_fp32=True,
                )

        def forward(self, inp, mask):
            if self.fused_scaled_softmax is not None:
                ret = self.fused_scaled_softmax(inp, mask, self.scale)
            elif self.mask_inp:
                ret = self.softmax_fn.apply(inp, mask, self.scale)
            else:
                ret = self.softmax_fn.apply(inp, self.scale)
            if self.fake_bf16_io:
                ret = ret.type(torch.float16)
            return ret

    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256
    mask = None
    input_names = ["input", "mask"]
    inp_shape = [hidden_size, in_features, in_features, in_features]
    if softmax_fn == softmax_defs.ScaledUpperTriangMaskedSoftmax:
        # This variant is exercised with a 3D input.
        inp_shape = [hidden_size, in_features, in_features]
        kernel_str = "ScaledUpperTriangMaskedSoftmax"
        model = Test_Softmax(softmax_fn, fake_bf16_io)
    elif softmax_fn == softmax_defs.ScaledMaskedSoftmax:
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(hidden_size, 1, in_features, in_features, device="cuda", dtype=precision)
        mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
        kernel_str = "ScaledMaskedSoftmax"
        model = Test_Softmax(softmax_fn, fake_bf16_io, mask_inp=True)
    elif softmax_fn == softmax_defs.ScaledSoftmax:
        kernel_str = "ScaledSoftmax"
        model = Test_Softmax(softmax_fn, fake_bf16_io)
    elif softmax_fn == te.softmax.FusedScaleMaskSoftmax:
        kernel_str = "TorchSoftmax"
        model = Test_Softmax(softmax_fn, fake_bf16_io)
    input_tensor = torch.randn(*inp_shape, device="cuda")
    # WAR for BF16 test as ORT doesn't support BF16 IO: FP16 input for both BF16 and FP16 precision types
    input_tensor = input_tensor.half()
    high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
    fname = f"{kernel_str}{high_prec_str}.onnx"
    inp = (input_tensor, mask)
    do_export(model, inp, fname, input_names=input_names)
    te_outputs = te_infer(model, inp, is_fp8=False)
    serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
    # Genuine BF16 I/O is exported and serialized but not validated.
    if fake_bf16_io or precision != torch.bfloat16:
        validate_result(fname, inp, model, atol=1e-3, input_names=input_names, te_outputs=te_outputs)
756
757


758
759
760
# Test dynamically generated softmax mask.
# Softmax kernel only supports FP16 or BF16!
@skip_FP8
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"])
def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
    """Check that the ONNX-compatible causal mask matches the default softmax mask.

    Two comparisons are made: TE output with the default mask vs. TE output in
    ONNX export mode (must match exactly), and TE output with the default mask
    vs. the exported ONNX model (skipped for genuine BF16 I/O).

    NVTE_MASKED_SOFTMAX_FUSION is set to "0" for the duration of the test and
    restored afterwards so the setting does not leak into other tests that run
    in the same process.
    """
    fake_bf16_io = precision == "fake-torch.bfloat16"
    # reset precision to torch.bfloat16 after capturing fake BF16 mode
    precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision

    class Test_Softmax(nn.Module):
        """Wrapper around FusedScaleMaskSoftmax configured with a causal mask."""

        def __init__(self, use_onnx_mask_fn: bool, fake_bf16_io: bool):
            # NOTE: `use_onnx_mask_fn` is not read here; the mask implementation is
            # selected by whether the caller is inside `te.onnx_export(True)`.
            super().__init__()
            self.scale = 1 # arbitrary value
            self.fake_bf16_io = fake_bf16_io
            self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
                attn_mask_type="causal",
                mask_func=te.utils.attention_mask_func,
                softmax_in_fp32=True,
            )

        def forward(self, inp, mask):
            ret = self.fused_scaled_softmax(inp, mask, self.scale)
            if self.fake_bf16_io:
                ret = ret.type(torch.float16)
            return ret

    # Use NVTE_MASKED_SOFTMAX_FUSION to force TE to use forward_torch_softmax
    # even when is_in_onnx_export_mode()==False.
    fusion_env_var = "NVTE_MASKED_SOFTMAX_FUSION"
    saved_fusion_value = os.environ.get(fusion_env_var)
    os.environ[fusion_env_var] = "0"
    try:
        # Set dimensions (these are arbitrary).
        in_features = 64
        hidden_size = 256
        mask = None
        inp_shape = [hidden_size, in_features, in_features, in_features]
        input_tensor = torch.randn(*inp_shape, device="cuda")
        # WAR for BF16 test as ORT doesn't support BF16 IO: FP16 input for both BF16 and FP16 precision types
        input_tensor = input_tensor.half()
        inp = (input_tensor, mask)
        high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)

        # Compare the outputs of TE when using the default softmax mask
        # to the TE outputs produced when using the ONNX-compatible causal mask.
        model = Test_Softmax(use_onnx_mask_fn=False, fake_bf16_io=fake_bf16_io)
        te_outputs_default_mask = te_infer(model, inp, is_fp8=True)
        with te.onnx_export(True):
            # ONNX export mode forces use of the ONNX-compatible causal mask.
            model_onnx_mask = Test_Softmax(use_onnx_mask_fn=True, fake_bf16_io=fake_bf16_io)
            te_outputs_onnx_mask = te_infer(model_onnx_mask, inp, is_fp8=True)
        compare_outputs(te_outputs_default_mask, te_outputs_onnx_mask,
            atol=0, rtol=0, max_errors_printed=10, allow_cnt_errors=0, fname="softmax masking")

        # Compare the outputs of TE when using the default softmax mask
        # to the ORT ONNX outputs produced when using the ONNX-compatible causal mask.
        input_names = ["input", "mask"]
        kernel_str = "FusedScaleMaskSoftmax"
        fname = f"{kernel_str}{high_prec_str}.onnx"
        do_export(model, inp, fname, input_names=input_names)
        serialize_inputs_outputs(fname, inp, te_outputs=te_outputs_default_mask, input_names=input_names)
        if fake_bf16_io or precision != torch.bfloat16:
            validate_result(fname, inp, model_onnx_mask, atol=1e-3, input_names=input_names, te_outputs=te_outputs_default_mask)
    finally:
        # Restore the process environment exactly as it was before the test.
        if saved_fusion_value is None:
            os.environ.pop(fusion_env_var, None)
        else:
            os.environ[fusion_env_var] = saved_fusion_value


820
821
822
823
824
@pytest.mark.parametrize("scale_factor", [1])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize(
    "precision,      use_bias",[
    (torch.float32,  False),
    (torch.float32,  True),
    (torch.float16,  False),
    (torch.float16,  True),
    # Todo: cannot configure BF16 when bias is disabled (ORT issue?)
    (torch.bfloat16, False),
    # Todo: cannot configure BF16 when bias is enabled (ORT issue?)
    (torch.bfloat16, True),
])
def test_export_linear(
    seed_default_rng,
    scale_factor: float,
    use_fp8: bool,
    use_bias: bool,
    return_bias: bool,
    precision: torch.dtype
):
    """Export `te.Linear` to ONNX and validate the result against the TE output.

    BF16 configurations are exported and serialized but not validated.
    """
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Set dimensions (these are arbitrary).
    in_features = 64
    out_features = 256
    hidden_size = 256

    class Test_Linear(nn.Module):
        """Minimal module holding a single `te.Linear` layer."""

        def __init__(self,
                in_features,
                out_features,
                use_bias,
                return_bias,
                precision
            ):
            super().__init__()
            self.linear = te.Linear(
                in_features,
                out_features,
                bias=use_bias,
                return_bias=return_bias,
                params_dtype=precision
            )

        def forward(self, inp):
            return self.linear(inp)

    inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
    with te.fp8_autocast(enabled=use_fp8):
        model = Test_Linear(
            in_features,
            out_features,
            use_bias,
            return_bias,
            precision
        ).to(device='cuda')
        if use_fp8:
            set_layer_scale(model.linear, scale_factor, num_gemms=1)
        do_export(model, inp, fname, use_fp8)
        te_outputs = te_infer(model, inp, is_fp8=use_fp8)
        serialize_inputs_outputs(fname, inp, te_outputs)

        # BF16 outputs are not validated.
        if precision in (torch.bfloat16, ):
            return
        if not use_fp8:
            validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
        else:
            validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8, te_outputs=te_outputs)
898
899
900
901
902
903
904
905


@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
    "precision,      use_bias",[
    (torch.float32,  False),
    (torch.float32,  True),
    (torch.float16,  True),
    (torch.float16,  False),
    (torch.bfloat16, True),
    (torch.bfloat16, False),
])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm_linear(
    seed_default_rng,
    scale_factor: float,
    use_fp8: bool,
    use_bias: bool,
    return_bias: bool,
    return_layernorm_output: bool,
    precision: torch.dtype,
    zero_centered_gamma: bool
):
    """Export `te.LayerNormLinear` to ONNX and validate it against the TE output.

    BF16 configurations are exported and serialized but not validated.
    """
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256

    # The last input dimension is the layer's `hidden_size` (its first constructor argument).
    inp = torch.randn(in_features, hidden_size, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"

    with te.fp8_autocast(enabled=use_fp8):
        model = te.LayerNormLinear(
            hidden_size,
            3 * hidden_size,
            bias=use_bias,
            return_bias=return_bias,
            return_layernorm_output=return_layernorm_output,
            params_dtype=precision,
            zero_centered_gamma=zero_centered_gamma,
        ).to(device='cuda')
        if use_fp8:
            set_layer_scale(model, scale_factor, num_gemms=1)
        do_export(model, inp, fname, use_fp8)
        te_outputs = te_infer(model, inp, is_fp8=use_fp8)
        serialize_inputs_outputs(fname, inp, te_outputs)
        # BF16 outputs are not validated.
        if precision in (torch.bfloat16, ):
            return
        if not use_fp8:
            validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
        else:
            validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8, te_outputs=te_outputs)
961
962
963
964
965
966
967
968


@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
    "precision,      use_bias",[
    (torch.float32,  False),
    (torch.float32,  True),
    (torch.float16,  True),
    (torch.float16,  False),
    (torch.bfloat16, True),
    (torch.bfloat16, False),
])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm_mlp(
    seed_default_rng,
    scale_factor: float,
    use_fp8: bool,
    use_bias: bool,
    return_bias: bool,
    return_layernorm_output: bool,
    precision: torch.dtype,
    zero_centered_gamma: bool
):
    """Export `te.LayerNormMLP` to ONNX and validate it against the TE output.

    BF16 configurations are exported and serialized but not validated.
    """
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256
    ffn_hidden_size = 256

    # The last input dimension is the layer's `hidden_size` (its first constructor argument).
    inp = torch.randn(in_features, hidden_size, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}.onnx"
    with te.fp8_autocast(enabled=use_fp8):
        model = te.LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            bias=use_bias,
            return_bias=return_bias,
            return_layernorm_output=return_layernorm_output,
            params_dtype=precision,
            zero_centered_gamma=zero_centered_gamma,
        ).to(device='cuda')
        if use_fp8:
            set_layer_scale(model, scale_factor, num_gemms=2)
        do_export(model, inp, fname, use_fp8)
        te_outputs = te_infer(model, inp, is_fp8=use_fp8)
        serialize_inputs_outputs(fname, inp, te_outputs)
        # BF16 outputs are not validated.
        if precision in (torch.bfloat16, ):
            return
        if not use_fp8:
            validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
        else:
            validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8, te_outputs=te_outputs)
1024
1025
1026

@skip_FP8
@pytest.mark.parametrize(
    "precision,      use_mask, attn_mask_type", [
    (torch.float32,  False,    None),      # calls forward_torch_softmax
    (torch.float32,  True,     None),      # calls forward_torch_softmax
    (torch.float16,  False,    "causal"),  # calls ScaledUpperTriangMaskedSoftmax
    (torch.float16,  True,     "padding"), # calls ScaledMaskedSoftmax
    (torch.float16,  False,    "padding"), # calls ScaledSoftmax
    (torch.bfloat16, False,    "causal"),  # calls ScaledUpperTriangMaskedSoftmax
    (torch.bfloat16, True,     "padding"), # calls ScaledMaskedSoftmax
    (torch.bfloat16, False,    "padding"), # calls ScaledSoftmax
])
def test_export_core_attention(
    seed_default_rng,
    set_max_seq_len,
    precision: torch.dtype,
    use_mask: bool,
    attn_mask_type: Optional[str],
):
    """Export `te.attention.DotProductAttention` to ONNX and validate it.

    When `attn_mask_type` is None the module is built with a "causal" mask type
    and is exported with only the query/key/value inputs (no attention mask).
    BF16 configurations are exported and serialized but not validated.
    """
    # Set dimensions (these are arbitrary).
    seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
    qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels)

    query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    input_names = ["query", "key", "value", "attention_mask"]
    attention_mask = None
    if use_mask:
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
    inp = (query_layer, key_layer, value_layer, attention_mask)

    # The file name is derived from the *parametrized* mask type, before the
    # None -> 'causal' substitution below.
    mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    high_prec_str = dtype2str(precision)
    fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"

    if attn_mask_type is None:
        attn_mask_type = 'causal'
        input_names = ["query", "key", "value"]
        inp = (query_layer, key_layer, value_layer)
    model = te.attention.DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=kv_channels,
        attention_dropout=0.5,
        attn_mask_type=attn_mask_type,
    ).to(device='cuda')
    do_export(model,
            inp,
            fname,
            input_names=input_names,
            use_fp8=True)
    te_outputs = te_infer(model, inp, is_fp8=True)
    serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
    # BF16 outputs are not validated.
    if precision in (torch.bfloat16, ):
        return
    validate_result(fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs)
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096


# Mask configurations shared by the multi-head attention and transformer layer tests.
test_configs_multihead_attention = [
    #"use_mask, attn_mask_type"
    (False,    "causal"),  # calls ScaledUpperTriangMaskedSoftmax
    (True,     "padding"), # calls ScaledMaskedSoftmax
    (False,    "padding"), # calls ScaledSoftmax
]
# Every combination of input layernorm, attention type and QKV parameter fusion.
test_configs_attention_type = [
    #"input_layernorm, attention_type, fuse_qkv_params"
    (True,             "self",         True),
    (False,            "self",         True),
    (True,             "self",         False),
    (False,            "self",         False),
    (True,             "cross",        True),
    (False,            "cross",        True),
    (True,             "cross",        False),
    (False,            "cross",        False),
]
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type)
def test_export_multihead_attention(
    seed_default_rng,
    set_max_seq_len,
    use_fp8: bool,
    use_mask: bool,
    attn_mask_type: str,
    precision: torch.dtype,
    return_layernorm_output: bool,
    input_layernorm: bool,
    attention_type: str,
    fuse_qkv_params: bool
):
    """Export `te.attention.MultiHeadAttention` to ONNX and validate it.

    The model is exported with dynamic sequence and batch axes. For causal
    self-attention it is additionally validated with a shorter input sequence
    ("generative phase"). BF16 configurations are exported and serialized but
    not validated.
    """
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    hidden_size = 256
    sequence_length = 128
    batch_size = 4
    num_attention_heads = 32
    kv_channels = 8
    attention_dropout = 0.1
    layernorm_epsilon = 1e-5
    init_method = output_layer_init_method = get_default_init_method()
    attention_args = (
        hidden_size,
        num_attention_heads,
        kv_channels,
        attention_dropout,
        layernorm_epsilon,
        init_method,
        output_layer_init_method,
    )

    hidden_states_context = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
    attention_mask = None
    if use_mask and attn_mask_type != "causal":
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)

    encoder_output = None

    if attention_type == "cross":
        encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")

    fp8_str = "_fp8" if use_fp8 else ""
    dtype_str = dtype2str(precision)
    attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention"
    fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else ""
    attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    input_ln_str = "_input-ln" if input_layernorm else ""
    fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"

    model = te.attention.MultiHeadAttention(
        *attention_args,
        attn_mask_type=attn_mask_type,
        params_dtype=precision,
        return_layernorm_output=return_layernorm_output,
        input_layernorm=input_layernorm,
        attention_type=attention_type,
        fuse_qkv_params=fuse_qkv_params,
    ).to(device='cuda')

    inp_context = (hidden_states_context, attention_mask, encoder_output)
    input_names = ["hidden_states", "attention_mask", "encoder_output"]
    output_names=["attention_output", "attention_bias"]
    # Sequence and batch axes are dynamic so the exported graph accepts the
    # shorter generative-phase input used further below.
    do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names,
        dynamic_axes={"hidden_states": {0: "seq", 1:"bs"},
                      "attention_output": {0: "seq", 1:"bs"}})
    te_outputs = te_infer(model, inp_context, is_fp8=use_fp8)
    serialize_inputs_outputs(fname, inp_context, te_outputs, input_names=input_names, output_names=output_names)
    # BF16 outputs are not validated.
    if precision in (torch.bfloat16, ):
        return

    if not use_fp8:
        validate_result(fname, inp_context, model, atol=1e-3, input_names=input_names,
            output_names=output_names, te_outputs=te_outputs)
    else:
        validate_result(fname, inp_context, model, atol=1e-2, is_fp8=use_fp8,
            input_names=input_names, output_names=output_names, allow_cnt_errors=3,
            te_outputs=te_outputs)

    # In GPT generative phase (inference) the input sequence is smaller than the maximum
    # allowed sequence length and we want to test this condition.
    # Pretend that we're in generative phase when it makes sense (causal mask and self-attention).
    is_generative_phase = (attn_mask_type == "causal" and attention_type == "self")
    if is_generative_phase:
        seq_len_offset = 8
        hidden_states_generative = torch.randn(sequence_length-seq_len_offset, batch_size, hidden_size, dtype=precision, device="cuda")
        inp_generative = (hidden_states_generative, attention_mask, encoder_output)
        if not use_fp8:
            validate_result(fname, inp_generative, model, atol=1e-3, input_names=input_names, output_names=output_names)
        else:
            validate_result(fname, inp_generative, model, atol=1e-2, is_fp8=use_fp8,
                input_names=input_names, output_names=output_names, allow_cnt_errors=3)



1206
1207
1208
1209
1210
1211
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [
    #True, # TO DO: handle this
    False
])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_transformer_layer(
    seed_default_rng,
    set_max_seq_len,
    use_fp8: bool,
    use_mask: bool,
    attn_mask_type: str,
    output_layernorm: bool,
    precision: torch.dtype,
    fuse_qkv_params: bool,
    zero_centered_gamma: bool
):
    """Export `te.TransformerLayer` to ONNX and validate it against the TE output.

    BF16 configurations are exported and serialized but not validated.
    """
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Layer configuration
    hidden_size = 64
    sequence_length = 128
    batch_size = 1
    ffn_hidden_size = 256
    num_attention_heads = 4

    input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
    input_names = ["input", "attention_mask"]
    attention_mask = None
    if use_mask and attn_mask_type != "causal":
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
    inp = (input_tensor, attention_mask)

    fp8_str = "_fp8" if use_fp8 else ""
    fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
    high_prec_str = dtype2str(precision)
    attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx"

    model = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        self_attn_mask_type=attn_mask_type,
        output_layernorm=output_layernorm,
        params_dtype=precision,
        fuse_qkv_params=fuse_qkv_params,
        zero_centered_gamma=zero_centered_gamma).to(device='cuda')
    do_export(model, inp, fname, use_fp8, input_names=input_names)
    te_outputs = te_infer(model, inp, is_fp8=use_fp8)
    serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
    # BF16 outputs are not validated.
    if precision in (torch.bfloat16, ):
        return
    if not use_fp8:
        validate_result(fname, inp, model, atol=1e-3, input_names=input_names,
            te_outputs=te_outputs)
    else:
        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8, input_names=input_names,
            te_outputs=te_outputs)
1272

Neta Zmora's avatar
Neta Zmora committed
1273
1274
1275
1276
1277
1278
1279

@pytest.mark.parametrize("use_fp8", [True])
@pytest.mark.parametrize("ln_scale_factor", [448*2])
@pytest.mark.parametrize("gemm_scale_factors", [(224, 224,),])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_gemm_layernorm(
    seed_default_rng,
    use_fp8: bool,
    ln_scale_factor: float,
    gemm_scale_factors: Tuple[float, float],
    precision: torch.dtype,
    zero_centered_gamma: bool
):
    """This is a regression test for testing that all LN inputs have the same type.

    The test sets up GEMM with FP32 output which feeds into an LN that is configured
    with FP16 or BF16 weights and bias.

    BF16 configurations are exported and serialized but not validated.
    """

    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    out_features = 128
    hidden_size = 128
    in_features = 128

    class TestFP8_GEMM(nn.Module):
        """Casts the input and weight to FP8 and multiplies them with `fp8_gemm`."""

        def __init__(self, precision, use_bias, gelu, scale_factors):
            super().__init__()
            self.use_bias = use_bias
            self.gelu = gelu
            self.precision = precision

            self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
            self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
            nb_inp_scales, nb_weight_scales = 1, out_features
            act_scale_factor, weight_scale_factor = scale_factors
            self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
            self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)

            bias_size = nb_weight_scales
            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
            # Not read by forward(); kept so the seeded RNG stream is consumed
            # exactly as before.
            self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")

            self.inp_type = tex.DType.kFloat8E4M3
            self.weights_type = tex.DType.kFloat8E4M3
            self.outp_type = precision

        def forward(self, inp, weight):
            inp_fp8 = cast_to_fp8(
                inp,
                self.meta_inp,
                self.fp8_tensor_inp,
                self.inp_type)

            weight_fp8 = cast_to_fp8(
                weight,
                self.meta_weight,
                self.fp8_tensor_weight,
                self.weights_type)

            # Each operand is paired with its own FP8 dtype (weight first, then input).
            ret = fp8_gemm(
                weight_fp8,
                self.meta_weight.scale_inv,
                self.fp8_tensor_weight,
                self.weights_type,
                inp_fp8,
                self.meta_inp.scale_inv,
                self.fp8_tensor_inp,
                self.inp_type,
                self.outp_type,
                get_workspace(),
                bias=self.bias,
                use_bias=self.use_bias,
                use_split_accumulator=False)
            return ret

    class TestFP8_GemmLayernorm(nn.Module):
        """FP8 GEMM followed by an FP8 layernorm and a cast back from FP8."""

        def __init__(self) -> None:
            super().__init__()
            # `inp` is the test input created below; it is resolved when the
            # module is instantiated, after `inp` has been assigned.
            normalized_shape = torch.Size(inp.shape[1:])
            self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
            self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
            self.eps = 1e-6 # An arbitrary small value

            self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
            self.meta = create_meta(ln_scale_factor)
            self.fp8_type = tex.DType.kFloat8E4M3
            self.gemm = TestFP8_GEMM(
                precision, use_bias=False, gelu=False, scale_factors=gemm_scale_factors)

        def forward(self, inp, weight):
            x = self.gemm(inp, weight)
            x = texcpp.layernorm_fwd_fp8_inf(
                x,
                self.weight,
                self.bias,
                self.eps,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                zero_centered_gamma)

            x = cast_from_fp8(
                x,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16)
            return x

    inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda")
    weight = torch.randn(out_features, in_features, dtype=precision, device="cuda")
    model = TestFP8_GemmLayernorm()
    high_prec_str = dtype2str(precision)
    fp8_str = "_fp8" if use_fp8 else ""
    fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx"
    input_names = ['input', 'weight']
    do_export(model, (inp, weight), fname, use_fp8=use_fp8, input_names=input_names)
    te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
    serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
    # BF16 outputs are not validated.
    if precision not in (torch.bfloat16, ):
        validate_result(
            fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2,
            input_names=input_names, te_outputs=te_outputs)
Neta Zmora's avatar
Neta Zmora committed
1398
1399


1400
1401
@skip_FP8
@pytest.mark.parametrize("use_fp8", [True, False])
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [True])
def test_export_gpt_generation(
    seed_default_rng,
    set_max_seq_len,
    use_fp8: bool,
    precision: torch.dtype,
    zero_centered_gamma: bool
):
    """Test that the ONNX model can correctly handle inputs with different shapes and that
    the attention mask is adjusted on-the-fly to different sequence lengths.

    The layer is exported once with dynamic sequence/batch axes and the same ONNX file is
    then checked twice: with the full-length input used for the export ("context phase")
    and with a much shorter input ("generative phase").
    """

    # Skip FP8 tests on non-hopper devices
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    # Layer configuration
    hidden_size = 64
    sequence_length = 128
    batch_size = 1
    ffn_hidden_size = 256
    num_attention_heads = 4
    attention_mask = None
    use_mask = True
    attn_mask_type = "causal"
    fuse_qkv_params = True
    output_layernorm = False

    fp8_str = "_fp8" if use_fp8 else ""
    fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
    high_prec_str = dtype2str(precision)
    attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    fname = f"te.transformer_layer_generative{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx"

    model = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        self_attn_mask_type=attn_mask_type,
        output_layernorm=output_layernorm,
        params_dtype=precision,
        fuse_qkv_params=fuse_qkv_params,
        zero_centered_gamma=zero_centered_gamma).to(device='cuda')

    input_names = ["input"]
    output_names = ["output"]

    def infer_serialize_validate(inputs, **serialize_kwargs):
        # Run the TE module, record its inputs/outputs, then compare against the
        # exported ONNX file. The comparison is not performed for bfloat16.
        te_outputs = te_infer(model, inputs, is_fp8=use_fp8)
        serialize_inputs_outputs(
            fname, inputs, te_outputs, input_names=input_names, **serialize_kwargs)
        if precision != torch.bfloat16:
            validate_result(fname, inputs, model, atol=6e-3, is_fp8=use_fp8,
                input_names=input_names, te_outputs=te_outputs)

    # "Context phase": use full input sequence length
    input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
    inp = (input_tensor,)
    do_export(model, inp, fname, use_fp8,
        input_names=input_names, output_names=output_names,
        dynamic_axes={"input": {0: "seq", 1:"bs"},
                      "output": {0: "seq", 1:"bs"}, })
    infer_serialize_validate(inp, output_names=output_names)

    # "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8.
    sequence_length = 1 if not use_fp8 else 8
    input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
    inp = (input_tensor, attention_mask)
    # No re-export here: the file written during the context phase is reused.
    infer_serialize_validate(inp)
1471
1472


1473
1474
1475
1476
1477
1478
@pytest.mark.parametrize("enabled", [True, False])
def test_export_ctx_manager(enabled):
    """Check that ``te.onnx_export`` sets ONNX export mode only inside its ``with`` block."""
    # Export mode must be off before the context manager is entered.
    assert not is_in_onnx_export_mode()
    with te.onnx_export(enabled):
        # Inside the block the reported mode must match the requested flag.
        assert is_in_onnx_export_mode() == enabled
    # Export mode must be off again once the block has been left.
    assert not is_in_onnx_export_mode()