# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""

File containing torch.ops extensions and their corresponding ONNX symbolic functions.

Many transformer engine layers rely on custom calls from the transformer_engine_torch module, making ONNX export challenging because:
1. They often accept Python objects (quantizers), which ONNX does not support.
2. They are complex, incorporating fusions and precomputing certain values for backward passes—mechanisms unnecessary for ONNX export.

For these reasons, we introduce onnx_forward methods in each layer that are simpler and
primarily leverage torch operators with known ONNX symbolic functions.
These methods avoid fusions and backward pass precomputations.
The main consideration is quantization, which PyTorch does not natively support, so we implement ONNX symbolic functions ourselves.

Since ONNX does not yet support quantization, operators from TensorRT are employed.
The primary goal of ONNX export is to enable inference compatibility with TensorRT.

"""

import math
from typing import Optional, Tuple

import onnxscript
import torch
from onnx import defs
from onnxscript import opset18 as op

import transformer_engine_torch as tex

from .constants import MXFP8_BLOCK_SCALING_SIZE
from .export import is_in_onnx_export_mode
from .tensor.float8_tensor import Float8Quantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .utils import round_up_to_nearest_multiple

# Opset from TensorRT which supports FP8 quantization.
trt_opset = onnxscript.values.Opset("trt", version=1)

# ONNX GEMM for inference


def onnx_gemm(weight: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """ONNX GEMM used for inference.

    Flattens the input to 2-D, runs the inference GEMM, then restores the
    leading dimensions of the input on the output.
    """
    flat_inp = inp.reshape(-1, inp.shape[-1])
    flat_out = torch_onnx_gemm_inf_op(weight, flat_inp, bias)
    out_shape = inp.shape[:-1] + (-1,)
    return flat_out.reshape(out_shape)


@torch.library.custom_op("tex::gemm_inf", mutates_args=[])
def torch_onnx_gemm_inf_op(
    weight: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """Gemm used for inference -- weight is transposed"""
    out = inp @ weight.T
    if bias is not None:
        out = out + bias
    return out


@torch_onnx_gemm_inf_op.register_fake
def _(weight, inp, bias):
    """Fake gemm used for inference (shape/dtype propagation only)."""
    result = inp @ weight.T
    return result if bias is None else result + bias


def onnx_gemm_inf_symbolic(
    weight: onnxscript.onnx_types.TensorType,
    inp: onnxscript.onnx_types.TensorType,
    bias: onnxscript.onnx_types.TensorType,
) -> onnxscript.onnx_types.TensorType:
    """Symbolic gemm used for inference.

    Lowers tex::gemm_inf to ONNX Gemm computing inp @ weight^T + bias;
    transB=1 transposes the weight, matching the eager implementation.
    """
    return op.Gemm(inp, weight, bias, transA=0, transB=1)


# ONNX FP8 Quantization


@torch.library.custom_op("tex::fp8_quantize", mutates_args=[])
def onnx_quantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
    """Quantize to Float8Tensor used for inference."""
    device = tensor.device
    scale_t = torch.tensor(scale, dtype=torch.float32, device=device)
    amax_t = torch.tensor([1], dtype=torch.float32, device=device)
    fp8_quantizer = Float8Quantizer(scale_t, amax_t, tex.DType.kFloat8E4M3)
    return fp8_quantizer.quantize(tensor)._data


@onnx_quantize_fp8_op.register_fake
def _(tensor, *_):
    """Fake quantize to Float8Tensor used for inference."""
    shape = tensor.shape
    return torch.empty(shape, dtype=torch.uint8, device=tensor.device)


def onnx_quantize_fp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType,
    scale: float,
) -> onnxscript.onnx_types.UINT8:
    """Symbolic quantize used for inference.

    TensorRT's quantize op takes the inverse scale, so 1/scale is folded
    into a Constant node.
    """
    scale_inv_value = 1 / scale
    return TRT_FP8QuantizeLinear(tensor, op.Constant(value_float=scale_inv_value))


# Define the schema for the custom operator
# Define the schema for the custom operator.
# (Cleaned: stray scraped line-number lines inside the argument list broke the syntax.)
schema = defs.OpSchema(
    name="TRT_FP8QuantizeLinear",
    domain="trt",
    since_version=1,
    doc="TRT FP8 Quantize Linear used for inference.",
    inputs=[
        defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
        defs.OpSchema.FormalParameter(
            "scale_inv", "tensor(float)", "Inverse scale factor for quantization"
        ),
    ],
    outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")],
)

TRT_FP8QuantizeLinear = onnxscript.values.Op(
    opset=trt_opset, name="TRT_FP8QuantizeLinear", op_schema=schema
)


# ONNX FP8 Dequantization


@torch.library.custom_op("tex::fp8_dequantize", mutates_args=[])
def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    """Dequantize from Float8Tensor used for inference.

    Args:
        tensor: uint8 tensor holding FP8 (E4M3) data.
        scale_inv: inverse scale; the quantizer wants the forward scale,
            hence the 1 / scale_inv below.

    Returns:
        float32 dequantized tensor.
    """
    # (Cleaned: stray scraped line-number lines inside this body broke the syntax.)
    quantizer = Float8Quantizer(
        1 / scale_inv, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
    )
    quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32)
    return quantizer_tensor.dequantize()


@onnx_dequantize_fp8_op.register_fake
def _(tensor: torch.Tensor, _) -> torch.Tensor:
    """Fake dequantize from Float8Tensor used for inference."""
    shape = tensor.shape
    return torch.empty(shape, dtype=torch.float32, device=tensor.device)


def onnx_dequantize_fp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType
) -> onnxscript.onnx_types.TensorType:
    """Symbolic dequantize from Float8Tensor used for inference.

    Lowers tex::fp8_dequantize to the TensorRT custom op
    TRT_FP8DequantizeLinear (uint8 data + float inverse scale -> float).
    (Cleaned: stray scraped line-number lines in the signature broke the syntax.)
    """
    return TRT_FP8DequantizeLinear(tensor, scale_inv)


# Schema for the TensorRT FP8 dequantize custom operator.
# (Cleaned: stray scraped line-number lines inside the argument list broke the syntax.)
schema = defs.OpSchema(
    name="TRT_FP8DequantizeLinear",
    domain="trt",
    since_version=1,
    doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.",
    inputs=[
        defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
        defs.OpSchema.FormalParameter(
            "scale_inv", "tensor(float)", "Inverse scale factor for dequantization"
        ),
    ],
    outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
)

TRT_FP8DequantizeLinear = onnxscript.values.Op(
    opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema
)

# ONNX FP8 Current Scaling Quantization


@torch.library.custom_op("tex::fp8_cs_quantize", mutates_args=[])
def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize to FP8 with current scaling; returns (uint8, scale_inv)."""
    if tensor.dtype != torch.float32:
        tensor = tensor.to(torch.float32)
    device = tensor.device
    # Clamp amax away from zero so the scale stays finite.
    eps = torch.tensor(1e-12, dtype=torch.float32, device=device)
    amax = torch.maximum(tensor.abs().max(), eps)
    fp8_max = torch.tensor(448, dtype=torch.float32, device=device)
    scale = fp8_max / amax
    quantized = torch.ops.tex.fp8_quantize(tensor, scale)
    return quantized, 1 / scale


@onnx_cs_quantize_fp8_op.register_fake
def _(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fake current-scaling quantize: uint8 data plus a one-element scale_inv."""
    data = torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device)
    scale_inv = torch.ones(1, dtype=torch.float32, device=tensor.device)
    return data, scale_inv


def onnx_quantize_fp8_cs_symbolic(
    tensor: onnxscript.onnx_types.TensorType,
):
    """Symbolic quantize with current scaling; computes scale_inv from tensor."""
    # scale_inv = max(|tensor|, eps) / 448
    abs_max = op.ReduceMax(op.Abs(tensor), keepdims=0)
    eps = op.Constant(value_float=1.0e-12)
    clamped = op.Max(abs_max, eps)
    scale_inv = op.Div(clamped, op.Constant(value_float=448.0))
    quantized = TRT_FP8QuantizeLinear(tensor, scale_inv)
    return quantized, scale_inv


# ONNX MXFP8 Quantization


@torch.library.custom_op("tex::mxfp8_quantize", mutates_args=[])
def onnx_quantize_mxfp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize to MXFP8Tensor used for inference."""
    mxfp8_quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3)
    qt = mxfp8_quantizer(tensor)
    return qt._rowwise_data, qt._rowwise_scale_inv


@onnx_quantize_mxfp8_op.register_fake
def _(tensor: torch.Tensor):
    """Fake quantize to MXFP8Tensor used for inference."""
    # Scale tensor is padded: rows to a multiple of 128, columns (one scale
    # per MXFP8 block) to a multiple of 4.
    rows = round_up_to_nearest_multiple(math.prod(tensor.shape[:-1]), 128)
    cols = round_up_to_nearest_multiple(tensor.shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4)
    data = torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device)
    scales = torch.empty([rows, cols], dtype=torch.uint8, device=tensor.device)
    return data, scales


def onnx_quantize_mxfp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType,
) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]:
    """Symbolic quantize to MXFP8Tensor used for inference.

    Lowers tex::mxfp8_quantize to the TensorRT custom op
    TRT_MXFP8DynamicQuantize, which returns the quantized data and the
    block scale tensor. (Cleaned: stray scraped line-number lines inside
    this body broke the syntax.)
    """
    tensor_out, scale_inv_out = TRT_MXFP8DynamicQuantize(tensor)
    return tensor_out, scale_inv_out


# Schema for the TensorRT MXFP8 dynamic quantize custom operator.
# (Cleaned: long runs of stray scraped line-number lines broke the syntax.)
schema = defs.OpSchema(
    name="TRT_MXFP8DynamicQuantize",
    domain="trt",
    since_version=1,
    doc="TRT MXFP8 Quantize Linear used for inference.",
    inputs=[
        defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
    ],
    outputs=[
        defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor"),
        defs.OpSchema.FormalParameter(
            "scale_inv", "tensor(uint8)", "Scale factor for quantization"
        ),
    ],
)

TRT_MXFP8DynamicQuantize = onnxscript.values.Op(
    opset=trt_opset, name="TRT_MXFP8DynamicQuantize", op_schema=schema
)


# ONNX MXFP8 Dequantization


@torch.library.custom_op("tex::mxfp8_dequantize", mutates_args=[])
def onnx_dequantize_mxfp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    """Dequantize from MXFP8Tensor used for inference."""
    mxfp8_quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3)
    mxfp8_tensor = mxfp8_quantizer.create_tensor_from_data(
        tensor, scale_inv, fake_dtype=torch.float32
    )
    return mxfp8_tensor.dequantize()


@onnx_dequantize_mxfp8_op.register_fake
def _(tensor: torch.Tensor, _):
    """Fake dequantize from MXFP8Tensor used for inference."""
    shape = tensor.shape
    return torch.empty(shape, dtype=torch.float32, device=tensor.device)


def onnx_dequantize_mxfp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType
) -> onnxscript.onnx_types.TensorType:
    """Symbolic dequantize from MXFP8Tensor used for inference.

    Lowers tex::mxfp8_dequantize to the TensorRT custom op
    TRT_MXFP8DequantizeLinear (uint8 data + uint8 block scales -> float).
    """
    return TRT_MXFP8DequantizeLinear(tensor, scale_inv)


# Schema for the TensorRT MXFP8 dequantize custom operator: uint8 rowwise data
# plus a uint8 block-scale tensor produce a float output.
schema = defs.OpSchema(
    name="TRT_MXFP8DequantizeLinear",
    domain="trt",
    since_version=1,
    doc="TRT MXFP8 Dequantize Linear from MXFP8Tensor used for inference.",
    inputs=[
        defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
        defs.OpSchema.FormalParameter(
            "scale_inv", "tensor(uint8)", "Scale factor for dequantization"
        ),
    ],
    outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
)

TRT_MXFP8DequantizeLinear = onnxscript.values.Op(
    opset=trt_opset, name="TRT_MXFP8DequantizeLinear", op_schema=schema
)


# ONNX LayerNorm


@torch.library.custom_op("tex::layernorm", mutates_args=[])
def onnx_layernorm_op(
    inp: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float
) -> torch.Tensor:
    """ONNX LayerNorm used for inference."""
    # NOTE(review): normalized size is taken from inp.shape[1], which assumes a
    # 2-D (tokens, hidden) input — confirm callers never pass higher-rank input.
    # NOTE(review): presumably tex.LayerNorm is a module whose constructor takes
    # (hidden_size, eps=...) and exposes weight/bias — verify against the
    # transformer_engine_torch bindings.
    model = tex.LayerNorm(inp.shape[1], eps=eps)
    model.weight.data = weight
    model.bias.data = bias
    return model(inp)


@onnx_layernorm_op.register_fake
def _(inp, *_):
    """Fake ONNX LayerNorm used for inference.

    Returns a fresh tensor rather than `inp` itself: the op is registered as
    non-aliasing (mutates_args=[]), and a fake impl that returns its input
    would claim the output aliases the input, confusing fake-tensor tracing.
    """
    return torch.empty_like(inp)


def onnx_layernorm_symbolic(
    inp: onnxscript.onnx_types.TensorType,
    weight: onnxscript.onnx_types.TensorType,
    bias: onnxscript.onnx_types.TensorType,
    eps: float,
) -> onnxscript.onnx_types.TensorType:
    """Symbolic ONNX LayerNorm used for inference."""
    normalized = op.LayerNormalization(inp, weight, bias, epsilon=eps)
    return normalized


# onnx layernorm helper function - handles layernorm with quantization


def onnx_layernorm(
    inp: torch.Tensor,
    layer_norm_weight: torch.Tensor,
    layer_norm_bias: torch.Tensor,
    eps: float,
    normalization: str,
    zero_centered_gamma: bool,
    output_dtype: torch.dtype,
    return_layernorm_output: bool,
    input_quantizer,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """LayerNorm/RMSNorm with optional fake quantization, for ONNX export.

    Args:
        inp: input tensor, normalized over its last dimension.
        layer_norm_weight: gamma; when ``zero_centered_gamma`` is True, 1 is
            added to it before use.
        layer_norm_bias: beta; may be None (not used by the RMSNorm branch).
        eps: numerical-stability epsilon for the norm.
        normalization: "RMSNorm" selects RMSNorm, anything else LayerNorm.
        zero_centered_gamma: see layer_norm_weight.
        output_dtype: dtype of the returned tensors and intermediate casts.
        return_layernorm_output: when True, the pre-quantization norm output
            is also cast through output_dtype (norm is not fused with the cast).
        input_quantizer: optional object with onnx_quantize/onnx_dequantize;
            when present the norm output goes through a quantize-dequantize
            round trip.

    Returns:
        Tuple (ln_out, ln_out_return): the (possibly fake-quantized) output in
        ``output_dtype``, and the norm output before quantization.
    """
    ln_weight = layer_norm_weight if not zero_centered_gamma else layer_norm_weight + 1
    # Compute in float32; the double cast reproduces the precision the values
    # would have after living in the lower-precision dtype.
    ln_weight = ln_weight.to(inp.dtype).to(torch.float32)
    inp = inp.to(torch.float32)
    layer_norm_bias = (
        layer_norm_bias.to(output_dtype).to(torch.float32) if layer_norm_bias is not None else None
    )

    if normalization == "RMSNorm":
        ln_out = torch.nn.functional.rms_norm(inp, inp.shape[-1:], ln_weight, eps)
    else:
        ln_out = torch.nn.functional.layer_norm(
            inp, inp.shape[-1:], ln_weight, layer_norm_bias, eps
        )
    ln_out_return = ln_out

    if input_quantizer is not None:
        if return_layernorm_output:
            # In case of return_layernorm_output, layernorm is not fused with fp8 cast,
            # so we cast to input_dtype and then perform cast to fp8 if needed
            ln_out = ln_out.to(output_dtype).to(torch.float32)
            ln_out_return = ln_out
        elif isinstance(input_quantizer, MXFP8Quantizer):
            # layernorm + mxfp8 quantizer behaves differently
            ln_out = ln_out.to(output_dtype).to(torch.float32)
        ln_out_quantized = input_quantizer.onnx_quantize(ln_out)
        ln_out = input_quantizer.onnx_dequantize(ln_out_quantized)
    ln_out = ln_out.to(output_dtype)
    return ln_out, ln_out_return


# utility functions


def onnx_attention_mask_func(
    attention_scores: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Apply the boolean attention mask to attention scores (ONNX-export path)."""
    assert is_in_onnx_export_mode()
    masked_scores = attention_scores.masked_fill(attention_mask, -10000.0)
    return masked_scores


# This translation table should be passed to torch.onnx.export function
# using the custom_translation_table=te_translation_table option.
# Maps each tex custom op to its ONNX symbolic function.
# (Cleaned: stray scraped line-number lines inside the dict broke the syntax.)
te_translation_table = {
    torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic,
    torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic,
    torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic,
    torch.ops.tex.fp8_cs_quantize.default: onnx_quantize_fp8_cs_symbolic,
    torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic,
    torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic,
    torch.ops.tex.layernorm.default: onnx_layernorm_symbolic,
}