Unverified Commit 6c9ce179 authored by asfiyab-nvidia, committed by GitHub

Add ONNX export support for TE modules (#41)



* Add ONNX export support for TE modules (#1)

* Add TorchScript Operators
* Add symbolic methods to ONNX exporter
* Add tests for the ONNX export
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fixes for pylint tests
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix pylint warning in softmax.py
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* move FP8 ORT lib inside tests/
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* enable cross attention tests
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* refactor code by @nzmora
* Increase layernorm FP16 threshold
* Normalize onnx file names: _ separates configs; - separates words in a single config
* Add get_attn_mask_str and fix mask string
* Add missing ONNX files
* Moved generated ONNX files to tests/gen_onnx_models/
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix merge conflict changes
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix Q/DQ scale input
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* enable FP16 config when bias is disabled
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix pylint check errors
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* updates
1. remove List import for pylint failure
2. address comments: remove state tensors from GPU
3. address comments: Update reverse_map_dtype function and add to namespace
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* minor fix: coding guidelines
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* changes:
1. skip FP8 tests on non-Hopper devices
2. minor fix for C++ lint check
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix onnxruntime version
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* minor fix: add space between code and comment
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* changes
1. update copyrights
2. update path to ORT .so
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: asfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com>
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Signed-off-by: asfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e2ad34e9
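A minimal usage sketch (not part of this change) of exporting a TE module to ONNX once the symbolics below are registered. The module choice, shapes, file name, and opset version are assumptions for illustration; FP8 export additionally requires supported hardware (e.g. Hopper).

import torch
import transformer_engine.pytorch as te

# Importing transformer_engine.pytorch registers the ONNX symbolic functions,
# so torch.onnx.export can translate the tex_ts custom ops emitted during tracing.
model = te.Linear(128, 128).cuda().eval()
inp = torch.randn(32, 128, device="cuda", dtype=torch.float32)

with torch.inference_mode(), te.fp8_autocast(enabled=True):
    torch.onnx.export(model, (inp,), "te_linear_fp8.onnx", opset_version=15)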
......@@ -8,6 +8,7 @@
*.nsys-rep
*.ncu-rep
*.sqlite
*.onnx
.eggs
build/
*.so
......
......@@ -6,5 +6,6 @@ set -e
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==6.2.5
pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/test_transformerengine.py
pytest -v -s $TE_PATH/tests/test_onnx_export.py
......@@ -14,7 +14,7 @@ from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
from distutils.version import LooseVersion
from distutils.file_util import copy_file
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
path = os.path.dirname(os.path.realpath(__file__))
......@@ -85,6 +85,7 @@ include_dirs = make_abs_path(include_dirs)
pytorch_sources = [
"transformer_engine/pytorch/csrc/extensions.cu",
"transformer_engine/pytorch/csrc/common.cu",
"transformer_engine/pytorch/csrc/ts_fp8_op.cpp",
]
pytorch_sources = make_abs_path(pytorch_sources)
......
This diff is collapsed.
......@@ -10,3 +10,12 @@ from .module import LayerNorm
from .transformer import TransformerLayer
from .fp8 import fp8_autocast
from .distributed import checkpoint
# Register custom op symbolic ONNX functions
from .te_onnx_extensions import (
onnx_cast_to_fp8,
onnx_cast_from_fp8,
onnx_fp8_gelu,
onnx_te_gemm,
onnx_layernorm_fwd_fp8,
onnx_layernorm_fwd,
)
......@@ -12,9 +12,11 @@ from .constants import TE_DType
def fp8_gemm(
A: torch.Tensor,
A_scale_inv: torch.Tensor,
A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
A_dtype: tex.DType,
B: torch.Tensor,
B_scale_inv: torch.Tensor,
B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
B_dtype: tex.DType,
out_dtype: torch.dtype,
workspace: torch.Tensor,
......@@ -41,19 +43,21 @@ def fp8_gemm(
out_dtype = tex.DType.kFloat32 if fp32_output else TE_DType[out_dtype]
tex.te_gemm(
_ = torch.ops.tex_ts.te_gemm_ts(
A,
A_scale_inv,
A_fp8_tensor,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor,
B_dtype,
False, # transb
out,
out_dtype,
bias if use_bias else empty_tensor,
empty_tensor,
empty_tensor, # this is pre_gelu_out
False, # grad
workspace,
workspace.shape[0],
......@@ -87,6 +91,7 @@ def gemm(
transa = layout[0] == "T"
transb = layout[1] == "T"
empty_tensor = torch.Tensor()
fp8_index = -1 # dummy index
input_dtype = TE_DType[dtype]
output_dtype = tex.DType.kFloat32 if fp32_output else input_dtype
......@@ -115,13 +120,15 @@ def gemm(
bias = bias if use_bias else empty_tensor
tex.te_gemm(
_ = torch.ops.tex_ts.te_gemm_ts(
A,
empty_tensor,
fp8_index,
input_dtype,
transa,
B,
empty_tensor,
fp8_index,
input_dtype,
transb,
out,
......@@ -214,11 +221,12 @@ def fp8_gelu(
otype: tex.DType,
) -> torch.Tensor:
"""GeLU with FP8 output"""
return tex.fp8_gelu(
return torch.ops.tex_ts.fp8_gelu_ts(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
)
......@@ -245,6 +253,48 @@ def layernorm_fwd_fp8(
)
def layernorm_fwd_fp8_inf(
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> torch.Tensor:
"""LayerNorm with FP8 output.
This version of layernorm_fwd_fp8 is specialized for inference, and returns
only the normalized output.
"""
ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype)
return ret
def layernorm_fwd_inf(
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
) -> torch.Tensor:
"""LayerNorm with FP8 output"""
return torch.ops.tex_ts.layernorm_fwd_inf_ts(
inp,
weight,
bias,
eps,
)
def cast_to_fp8(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
......@@ -252,11 +302,12 @@ def cast_to_fp8(
otype: tex.DType,
) -> torch.Tensor:
"""Cast input to FP8"""
return tex.cast_to_fp8(
return torch.ops.tex_ts.cast_to_fp8_ts(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
)
......@@ -269,9 +320,10 @@ def cast_from_fp8(
otype: tex.DType,
) -> torch.Tensor:
"""Cast input from FP8"""
return tex.cast_from_fp8(
return torch.ops.tex_ts.cast_from_fp8_ts(
inp,
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.scale_inv,
fp8_tensor,
itype,
otype,
)
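The cpp_extensions wrappers above now route through the tex_ts TorchScript ops, and the new *_inf variants return only the normalized output. A minimal, hedged sketch of calling the non-FP8 inference wrapper (the import path and a CUDA device are assumptions):

import torch
import transformer_engine.pytorch.cpp_extensions as texcpp

x = torch.randn(32, 128, device="cuda", dtype=torch.float32)
gamma = torch.ones(128, device="cuda", dtype=torch.float32)
beta = torch.zeros(128, device="cuda", dtype=torch.float32)
# Returns only the normalized output (no mu/rsigma), keeping exported graphs
# free of training-only tensors.
y = texcpp.layernorm_fwd_inf(x, gamma, beta, 1e-5)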
......@@ -94,6 +94,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
return transformer_engine::DType::kFloat32;
case at::kBFloat16:
return transformer_engine::DType::kBFloat16;
case at::kBool:
return transformer_engine::DType::kByte;
default:
NVTE_ERROR("Invalid type");
}
......
......@@ -397,6 +397,23 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
}
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
// This is a specialized version of layernorm_fwd_fp8, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd_fp8(
input, weight, bias, eps, scale, amax, scale_inv, otype);
return out[0];
}
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
......@@ -428,6 +445,16 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
return {ln_out, mu, rsigma};
}
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps
) {
// This is a specialized version of layernorm_fwd, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps);
return out[0];
}
at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale,
......
......@@ -95,6 +95,15 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
transformer_engine::DType otype
);
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const at::Tensor &weight,
......@@ -102,6 +111,11 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
float eps
);
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps
);
at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale,
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <torch/script.h>
#include "extensions.h"
namespace {
transformer_engine::DType reverse_map_dtype(int64_t dtype) {
if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) {
return static_cast<transformer_engine::DType>(dtype);
} else {
NVTE_ERROR("Type not supported.");
}
}
} // namespace
at::Tensor cast_to_fp8_ts(const at::Tensor &input,
const at::Tensor &scale,
const at::Tensor &amax,
const at::Tensor &scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor output = cast_to_fp8(input,
scale[fp8_tensor],
amax[0][fp8_tensor],
scale_inv[fp8_tensor],
otype_arg);
return output.clone();
}
at::Tensor cast_from_fp8_ts(const at::Tensor &input,
const at::Tensor &scale_inv,
int64_t fp8_tensor,
int64_t itype,
int64_t otype) {
transformer_engine::DType itype_arg = reverse_map_dtype(itype);
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor output = cast_from_fp8(input,
scale_inv[fp8_tensor],
itype_arg,
otype_arg);
return output.clone();
}
at::Tensor fp8_gelu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor output = fp8_gelu(input,
scale[fp8_tensor],
amax[0][fp8_tensor],
scale_inv[fp8_tensor],
otype_arg);
return output.clone();
}
at::Tensor te_gemm_ts(at::Tensor A,
at::Tensor A_scale_inverse,
int64_t A_fp8_tensor,
int64_t A_type,
int64_t transa,
at::Tensor B,
at::Tensor B_scale_inverse,
int64_t B_fp8_tensor,
int64_t B_type,
int64_t transb,
at::Tensor D,
int64_t D_type,
at::Tensor bias,
at::Tensor pre_gelu_out,
int64_t grad,
at::Tensor workspace,
int64_t workspaceSize,
int64_t accumulate,
int64_t use_split_accumulator) {
// cast inputs to types accepted by te_gemm
transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
bool transa_arg = static_cast<bool>(transa);
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
at::Tensor A_scale_inverse_arg = A_scale_inverse.clone();
if (A_scale_inverse.numel())
A_scale_inverse_arg = A_scale_inverse[A_fp8_tensor];
at::Tensor B_scale_inverse_arg = B_scale_inverse.clone();
if (B_scale_inverse.numel())
B_scale_inverse_arg = B_scale_inverse[B_fp8_tensor];
te_gemm(A,
A_scale_inverse_arg,
A_type_arg,
transa_arg,
B,
B_scale_inverse_arg,
B_type_arg,
transb_arg,
D,
D_type_arg,
bias,
pre_gelu_out,
grad_arg,
workspace,
workspaceSize_arg,
accumulate_arg,
use_split_accumulator_arg);
return D;
}
at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
double eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps);
at::Tensor output = layernorm_fwd_fp8_inf(input,
weight,
bias,
eps_float,
scale,
amax,
scale_inv,
otype_arg);
return output.clone();
}
at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
double eps) {
float eps_float = static_cast<float>(eps);
at::Tensor output = layernorm_fwd_inf(input,
weight,
bias,
eps_float);
return output.clone();
}
TORCH_LIBRARY(tex_ts, m) {
m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
m.def("fp8_gelu_ts", &fp8_gelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
}
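TORCH_LIBRARY places these wrappers in the tex_ts namespace, so they become callable as torch.ops.tex_ts.* once the compiled extension is imported. A minimal, hedged sketch of invoking one directly; tensor shapes follow the indexing done by cast_to_fp8_ts above (scale[fp8_tensor], amax[0][fp8_tensor], scale_inv[fp8_tensor]), and the enum-to-int conversion mirrors the int(tex.DType...) usage elsewhere in this change:

import torch
import transformer_engine_extensions as tex  # importing the extension registers the tex_ts ops

x = torch.randn(64, 64, device="cuda", dtype=torch.float32)
scale = torch.ones(1, device="cuda")      # indexed as scale[fp8_tensor]
amax = torch.zeros(1, 1, device="cuda")   # indexed as amax[0][fp8_tensor]
scale_inv = torch.ones(1, device="cuda")  # indexed as scale_inv[fp8_tensor]

fp8_x = torch.ops.tex_ts.cast_to_fp8_ts(
    x, scale, amax, scale_inv,
    0,                               # fp8_tensor: index into the meta tensors
    int(tex.DType.kFloat8E4M3))      # otype passed as an integer
# fp8_x is a torch.uint8 tensor holding the FP8 payload.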
This diff is collapsed.
......@@ -5,9 +5,10 @@
"""Fused scaled masked softmax functions"""
import os
from typing import Callable, Tuple, Union
import torch
from torch import nn
import torch._C._onnx as _C_onnx
from torch.onnx import _type_utils
import transformer_engine_extensions as tex
......@@ -46,6 +47,36 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
return input_grads, None
@staticmethod
def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
"""ScaledUpperTriangMaskedSoftmax symbolic method"""
def triangular_mask():
dtype = _type_utils.JitScalarType.INT64
ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype)
k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
mask = g.op("Trilu", ones, k, upper_i=1)
mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
return mask
# Captures the logic of function scaled_upper_triang_masked_softmax_warp_forward
if inputs.type().scalarType() == "BFloat16":
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
mask = triangular_mask()
one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
inv_mask = g.op("Sub", one, mask)
neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
softmax_mask = g.op("Mul", mask, neg_tenK)
scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
scaled = g.op("Mul", inputs, scale_input)
masked_scaled = g.op("Mul", inv_mask, scaled)
masked = g.op("Add", masked_scaled, softmax_mask)
out = g.op("Softmax", masked)
if inputs.type().scalarType() == "BFloat16":
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return out
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
......@@ -78,6 +109,35 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
)
return input_grads, None, None
@staticmethod
def symbolic(
g: torch.Graph,
inputs: torch._C.Value,
mask: torch._C.Value,
scale: float) -> torch._C.Value:
"""ScaledMaskedSoftmax symbolic method"""
# Captures the logic of function scaled_masked_softmax_warp_forward.
# output = softmax(mask(input*scale))
# Computed as:
# masked_scaled = (1 - mask)*(input*scale)
# softmax_mask = mask * -10000
# output = softmax(masked_scaled + softmax_mask)
if inputs.type().scalarType() == "BFloat16":
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
scaled = g.op("Mul", inputs, scale_input)
one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
inv_mask = g.op("Sub", one, mask)
# Note: type is hard coded because softmax uses FP16 or BF16
neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
softmax_mask = g.op("Mul", mask, neg_tenK)
masked_scaled = g.op("Mul", inv_mask, scaled)
masked = g.op("Add", masked_scaled, softmax_mask)
out = g.op("Softmax", masked)
if inputs.type().scalarType() == "BFloat16":
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return out
class ScaledSoftmax(torch.autograd.Function):
"""
......@@ -107,6 +167,19 @@ class ScaledSoftmax(torch.autograd.Function):
)
return input_grads, None, None
@staticmethod
def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
"""ScaledSoftmax symbolic method"""
if inputs.type().scalarType() == "BFloat16":
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
scaled = g.op("Mul", inputs, scale_input)
out = g.op("Softmax", scaled)
if inputs.type().scalarType() == "BFloat16":
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return out
class FusedScaleMaskSoftmax(nn.Module):
"""
......@@ -163,7 +236,7 @@ class FusedScaleMaskSoftmax(nn.Module):
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(sk)
batch_per_block = self.get_batch_per_block(int(sk))
if self.attn_mask_type == "causal":
if attn_batches % batch_per_block == 0:
......
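For reference (not part of this commit), the ONNX graph built by ScaledMaskedSoftmax.symbolic corresponds to the following eager-mode computation, where mask holds 1 at masked positions:

import torch

def scaled_masked_softmax_reference(inputs: torch.Tensor,
                                    mask: torch.Tensor,
                                    scale: float) -> torch.Tensor:
    mask = mask.to(inputs.dtype)  # 1.0 at masked positions, 0.0 elsewhere
    # masked_scaled = (1 - mask) * (inputs * scale); masked positions are pushed to -10000
    masked = (1 - mask) * (inputs * scale) + mask * (-10000.0)
    # ONNX Softmax defaults to the last axis (opset >= 13)
    return torch.softmax(masked, dim=-1)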
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
ONNX symbolic functions for Transformer Engine
Warnings of the type pasted below are a known PyTorch issue
(https://github.com/pytorch/pytorch/issues/81693):
tests/test_onnx_export.py::test_export_cast_ops[112]
/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py:649:
UserWarning: The shape inference of trt::TRT_FP8DequantizeLinear type is missing,
so it may result in wrong shape inference for the exported graph.
Please consider adding it in symbolic function. (Triggered internally at
/opt/pytorch/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1880.)
_C._jit_pass_onnx_graph_shape_type_inference(
Scale tensors are treated as lists ("fs") instead of tensors ("v") because we need to access
specific entries using the index passed as `fp8_tensor`. If you fail to do this you will get
the following error when accessing a specific scale element (e.g. `scale_inv[fp8_tensor]`):
TypeError: 'torch._C.Value' object is not subscriptable
"""
import torch
from torch.onnx import symbolic_helper, register_custom_op_symbolic
import torch._C._onnx as _C_onnx
import transformer_engine_extensions as tex
# This file registers custom op symbolic ONNX functions and does not export any symbols.
__all__ = []
# Custom ops spec version
VER = 1
UNSPECIFIED_TYPE = -1
def make_op_name(op_name: str) -> str:
"""custom op name"""
return "trt::" + op_name
def quantize(g, inputs, scale_inv, fp8_tensor):
"""Helper Function for Quantization"""
output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
# Q inputs are currently constrained to FP32 due to a similar limitation in ORT
# custom ops, so cast the input if needed.
if inputs.type().scalarType() == "Half":
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
q_op = g.op(
make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType(
inputs.type().with_dtype(torch.uint8).with_sizes(output_shape))
return q_op
def dequantize(g, inputs, scale_inv, fp8_tensor, otype):
"""Helper Function for Dequantization"""
output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
out = g.op(make_op_name("TRT_FP8DequantizeLinear"), inputs, scale).setType(
inputs.type().with_dtype(torch.float32).with_sizes(output_shape))
# DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
# custom ops, so cast the output if needed.
if otype == int(tex.DType.kFloat16):
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return out
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for cast_to_fp8"""
# pylint: disable=unused-argument
return quantize(g, inputs, scale_inv, fp8_tensor)
@symbolic_helper.parse_args("v", "fs", "i", "i", "i")
def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
"""ONNX graph for cast_from_fp8"""
# pylint: disable=unused-argument
return dequantize(g, inputs, scale_inv, fp8_tensor, otype)
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_gelu"""
# pylint: disable=unused-argument
gelu = torch.onnx.symbolic_opset9.gelu(g, inputs, "tanh")
out = quantize(g, gelu, scale_inv, fp8_tensor)
return out
@symbolic_helper.parse_args("v", "fs", "i", "i", "i",
"v", "fs", "i", "i", "i",
"v", "i", "v", "v", "i",
"v", "i", "i", "i")
def onnx_te_gemm(
g,
weight,
weight_scale_inverse,
weight_fp8_tensor,
weight_type,
trans_weight,
inputs,
input_scale_inverse,
input_fp8_tensor,
input_type,
trans_input,
out,
out_type,
bias,
pre_gelu_out,
grad,
workspace,
workspaceSize,
accumulate,
use_split_accumulator):
"""ONNX graph for te_gemm"""
# pylint: disable=unused-argument
is_fp16 = bias.type().scalarType() == "Half"
if input_type == int(tex.DType.kFloat8E4M3):
inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, UNSPECIFIED_TYPE)
if weight_type == int(tex.DType.kFloat8E4M3):
weight = dequantize(g, weight, weight_scale_inverse, weight_fp8_tensor, UNSPECIFIED_TYPE)
output = g.op("Gemm", inputs, weight, transA_i=trans_input, transB_i=trans_weight)
empty_tensor_size = [0]
bias_empty = torch.onnx.symbolic_helper._get_tensor_sizes(bias) == empty_tensor_size
pre_gelu_out_empty = torch.onnx.symbolic_helper._get_tensor_sizes(pre_gelu_out) \
== empty_tensor_size
if not bias_empty:
if pre_gelu_out_empty:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
output = g.op('Add', output, bias)
else:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
output = g.op('Add', output, bias)
output = torch.onnx.symbolic_opset9.gelu(g, output)
else:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return output
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i")
def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for layernorm_fwd_fp8"""
# pylint: disable=unused-argument
ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
return fp8_ln
@symbolic_helper.parse_args("v", "v", "v", "f")
def onnx_layernorm_fwd(g, inputs, weight, bias, eps):
"""ONNX graph for layernorm_fwd"""
# pylint: disable=unused-argument
normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
if normalized_shape is None:
ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs)
assert ndim is not None
normalized_shape = list(range(0, ndim))
# Normalization axis = 0, so normalized_shape uses all dims except dim = 0
normalized_shape = normalized_shape[1:]
ln = torch.onnx.symbolic_opset9.layer_norm(
g,
inputs,
normalized_shape,
weight,
bias,
eps,
False # cudnn_enable (not relevant)
)
return ln
register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
register_custom_op_symbolic('tex_ts::fp8_gelu_ts', onnx_fp8_gelu, VER)
register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER)
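With the symbolic functions registered, one quick, hedged way to sanity-check an exported model (file name hypothetical) is to look for the custom Q/DQ nodes in the graph:

import onnx

model = onnx.load("te_linear_fp8.onnx")  # hypothetical file produced by torch.onnx.export
op_types = {node.op_type for node in model.graph.node}
# With FP8 enabled, the ops emitted by quantize()/dequantize() above should appear.
print("TRT_FP8QuantizeLinear" in op_types, "TRT_FP8DequantizeLinear" in op_types)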