Unverified commit 6c9ce179, authored by asfiyab-nvidia, committed by GitHub

Add ONNX export support for TE modules (#41)



* Add ONNX export support for TE modules (#1)

* Add TorchScript Operators
* Add symbolic methods to ONNX exporter
* Add tests for the ONNX export
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fixes for pylint tests
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix pylint warning in softmax.py
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* move FP8 ORT lib inside tests/
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* enable cross attention tests
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* refactor code by @nzmora
* Increase layernorm FP16 threshold
* Normalize onnx file names: _ separates configs; - separates words in a single config
* Add get_attn_mask_str and fix mask string
* Add missing ONNX files
* Moved generated ONNX files to tests/gen_onnx_models/
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix merge conflict changes
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix Q/DQ scale input
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* enable FP16 config when bias is disabled
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix pylint check errors
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* updates
1. remove List import for pylint failure
2. address comments: remove state tensors from GPU
3. address comments: Update reverse_map_dtype function and add to namespace
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* minor fix: coding guidelines
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* changes:
1. skip FP8 tests on non-Hopper devices
2. minor fix for C++ lint check
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix onnxruntime version
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* minor fix: add space between code and comment
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* changes
1. update copyrights
2. update path to ORT .so
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: asfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com>
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Signed-off-by: asfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e2ad34e9
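For orientation, here is a minimal sketch of the workflow this commit enables (illustrative, not part of the diff): importing transformer_engine.pytorch registers the custom ONNX symbolic functions added below, after which a TE module can be exported with torch.onnx.export. The module choice, shapes, and opset are assumptions.

```python
import torch
import transformer_engine.pytorch as te  # import side effect registers the ONNX symbolics

model = te.Linear(1024, 1024, bias=True).cuda().eval()
inp = torch.randn(32, 1024, device="cuda")

# FP8 export (fp8_autocast(enabled=True)) additionally requires Hopper-class hardware,
# as noted in the commit message ("skip FP8 tests on non-Hopper devices").
with te.fp8_autocast(enabled=False):
    torch.onnx.export(model, (inp,), "te_linear.onnx", opset_version=14)
```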
@@ -8,6 +8,7 @@
 *.nsys-rep
 *.ncu-rep
 *.sqlite
+*.onnx
 .eggs
 build/
 *.so
@@ -6,5 +6,6 @@ set -e
 : ${TE_PATH:=/opt/transformerengine}
-pip install pytest==6.2.5
+pip install pytest==6.2.5 onnxruntime==1.13.1
 pytest -v -s $TE_PATH/tests/test_transformerengine.py
+pytest -v -s $TE_PATH/tests/test_onnx_export.py
@@ -14,7 +14,7 @@ from setuptools import setup, find_packages, Extension
 from setuptools.command.build_ext import build_ext
 from distutils.version import LooseVersion
 from distutils.file_util import copy_file
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME

 path = os.path.dirname(os.path.realpath(__file__))
@@ -85,6 +85,7 @@ include_dirs = make_abs_path(include_dirs)
 pytorch_sources = [
     "transformer_engine/pytorch/csrc/extensions.cu",
     "transformer_engine/pytorch/csrc/common.cu",
+    "transformer_engine/pytorch/csrc/ts_fp8_op.cpp",
 ]
 pytorch_sources = make_abs_path(pytorch_sources)
@@ -10,3 +10,12 @@ from .module import LayerNorm
 from .transformer import TransformerLayer
 from .fp8 import fp8_autocast
 from .distributed import checkpoint
+# Register custom op symbolic ONNX functions
+from .te_onnx_extensions import (
+    onnx_cast_to_fp8,
+    onnx_cast_from_fp8,
+    onnx_fp8_gelu,
+    onnx_te_gemm,
+    onnx_layernorm_fwd_fp8,
+    onnx_layernorm_fwd,
+)
@@ -12,9 +12,11 @@ from .constants import TE_DType
 def fp8_gemm(
     A: torch.Tensor,
     A_scale_inv: torch.Tensor,
+    A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
     A_dtype: tex.DType,
     B: torch.Tensor,
     B_scale_inv: torch.Tensor,
+    B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
     B_dtype: tex.DType,
     out_dtype: torch.dtype,
     workspace: torch.Tensor,
@@ -41,19 +43,21 @@
     out_dtype = tex.DType.kFloat32 if fp32_output else TE_DType[out_dtype]
-    tex.te_gemm(
+    _ = torch.ops.tex_ts.te_gemm_ts(
         A,
         A_scale_inv,
+        A_fp8_tensor,
         A_dtype,
         True,  # transa
         B,
         B_scale_inv,
+        B_fp8_tensor,
         B_dtype,
         False,  # transb
         out,
         out_dtype,
         bias if use_bias else empty_tensor,
-        empty_tensor,
+        empty_tensor,  # this is pre_gelu_out
         False,  # grad
         workspace,
         workspace.shape[0],
@@ -87,6 +91,7 @@ def gemm(
     transa = layout[0] == "T"
     transb = layout[1] == "T"
     empty_tensor = torch.Tensor()
+    fp8_index = -1  # dummy index
     input_dtype = TE_DType[dtype]
     output_dtype = tex.DType.kFloat32 if fp32_output else input_dtype
@@ -115,13 +120,15 @@
     bias = bias if use_bias else empty_tensor
-    tex.te_gemm(
+    _ = torch.ops.tex_ts.te_gemm_ts(
         A,
         empty_tensor,
+        fp8_index,
         input_dtype,
         transa,
         B,
         empty_tensor,
+        fp8_index,
         input_dtype,
         transb,
         out,
@@ -214,11 +221,12 @@ def fp8_gelu(
     otype: tex.DType,
 ) -> torch.Tensor:
     """GeLU with FP8 output"""
-    return tex.fp8_gelu(
+    return torch.ops.tex_ts.fp8_gelu_ts(
         inp,
-        fp8_meta_tensor.scale[fp8_tensor],
-        fp8_meta_tensor.amax_history[0][fp8_tensor],
-        fp8_meta_tensor.scale_inv[fp8_tensor],
+        fp8_meta_tensor.scale,
+        fp8_meta_tensor.amax_history,
+        fp8_meta_tensor.scale_inv,
+        fp8_tensor,
         otype,
     )
@@ -245,6 +253,48 @@ def layernorm_fwd_fp8(
     )
+
+
+def layernorm_fwd_fp8_inf(
+    inp: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float,
+    fp8_meta_tensor: tex.FP8TensorMeta,
+    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
+    otype: tex.DType,
+) -> torch.Tensor:
+    """LayerNorm with FP8 output.
+    This version of layernorm_fwd_fp8 is specialized for inference, and returns
+    only the normalized output.
+    """
+    ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts(
+        inp,
+        weight,
+        bias,
+        eps,
+        fp8_meta_tensor.scale,
+        fp8_meta_tensor.amax_history,
+        fp8_meta_tensor.scale_inv,
+        fp8_tensor,
+        otype)
+    return ret
+
+
+def layernorm_fwd_inf(
+    inp: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float,
+) -> torch.Tensor:
+    """LayerNorm with FP8 output"""
+    return torch.ops.tex_ts.layernorm_fwd_inf_ts(
+        inp,
+        weight,
+        bias,
+        eps,
+    )
+
+
 def cast_to_fp8(
     inp: torch.Tensor,
     fp8_meta_tensor: tex.FP8TensorMeta,
@@ -252,11 +302,12 @@ def cast_to_fp8(
     otype: tex.DType,
 ) -> torch.Tensor:
     """Cast input to FP8"""
-    return tex.cast_to_fp8(
+    return torch.ops.tex_ts.cast_to_fp8_ts(
        inp,
-        fp8_meta_tensor.scale[fp8_tensor],
-        fp8_meta_tensor.amax_history[0][fp8_tensor],
-        fp8_meta_tensor.scale_inv[fp8_tensor],
+        fp8_meta_tensor.scale,
+        fp8_meta_tensor.amax_history,
+        fp8_meta_tensor.scale_inv,
+        fp8_tensor,
        otype,
    )
@@ -269,9 +320,10 @@ def cast_from_fp8(
     otype: tex.DType,
 ) -> torch.Tensor:
     """Cast input from FP8"""
-    return tex.cast_from_fp8(
+    return torch.ops.tex_ts.cast_from_fp8_ts(
         inp,
-        fp8_meta_tensor.scale_inv[fp8_tensor],
+        fp8_meta_tensor.scale_inv,
+        fp8_tensor,
         itype,
         otype,
     )
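The wrappers above now pass the whole scale/amax/scale_inv tensors plus an integer index to the TorchScript ops instead of indexing on the Python side, which keeps the index visible to the ONNX exporter. Below is a hedged sketch of calling one of the new ops directly; tensor shapes, the slot index, and the meta-tensor layout are assumptions mirroring tex.FP8TensorMeta.

```python
import torch
import transformer_engine_extensions as tex

inp = torch.randn(16, 64, device="cuda", dtype=torch.float32)
# One slot per FP8 tensor; shapes here are illustrative.
scale = torch.ones(16, device="cuda")
amax_history = torch.zeros(1, 16, device="cuda")
scale_inv = torch.ones(16, device="cuda")

fp8_out = torch.ops.tex_ts.cast_to_fp8_ts(
    inp,
    scale,
    amax_history,
    scale_inv,
    0,                           # fp8_tensor: index into the meta tensors
    int(tex.DType.kFloat8E4M3),  # otype
)
```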
@@ -94,6 +94,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
       return transformer_engine::DType::kFloat32;
     case at::kBFloat16:
       return transformer_engine::DType::kBFloat16;
+    case at::kBool:
+      return transformer_engine::DType::kByte;
     default:
       NVTE_ERROR("Invalid type");
   }
@@ -397,6 +397,23 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
 }
+
+at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
+                                 const at::Tensor &weight,
+                                 const at::Tensor &bias,
+                                 float eps,
+                                 at::Tensor scale,
+                                 at::Tensor amax,
+                                 at::Tensor scale_inv,
+                                 transformer_engine::DType otype
+) {
+  // This is a specialized version of layernorm_fwd_fp8, optimized for inference,
+  // which only returns the normalized output.
+  std::vector<at::Tensor> out = layernorm_fwd_fp8(
+      input, weight, bias, eps, scale, amax, scale_inv, otype);
+  return out[0];
+}
+
 std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
                                       const at::Tensor &weight,
                                       const at::Tensor &bias,
@@ -428,6 +445,16 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
   return {ln_out, mu, rsigma};
 }
+
+at::Tensor layernorm_fwd_inf(const at::Tensor &input,
+                             const at::Tensor &weight,
+                             const at::Tensor &bias,
+                             float eps
+) {
+  // This is a specialized version of layernorm_fwd, optimized for inference,
+  // which only returns the normalized output.
+  std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps);
+  return out[0];
+}
+
 at::Tensor cast_to_fp8(const at::Tensor &input,
                        const at::Tensor &scale,
@@ -95,6 +95,15 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
                                           transformer_engine::DType otype
 );
+at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
+                                 const at::Tensor &weight,
+                                 const at::Tensor &bias,
+                                 float eps,
+                                 at::Tensor scale,
+                                 at::Tensor amax,
+                                 at::Tensor scale_inv,
+                                 transformer_engine::DType otype
+);

 std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
                                       const at::Tensor &weight,
@@ -102,6 +111,11 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
                                       float eps
 );
+at::Tensor layernorm_fwd_inf(const at::Tensor &input,
+                             const at::Tensor &weight,
+                             const at::Tensor &bias,
+                             float eps
+);

 at::Tensor cast_to_fp8(const at::Tensor &input,
                        const at::Tensor &scale,
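A sketch of how the inference-only LayerNorm entry points declared above are meant to be used from Python (shapes and epsilon are illustrative): unlike layernorm_fwd, they return only the normalized output, which is all the ONNX symbolic needs.

```python
import torch

hidden = 1024
x = torch.randn(8, hidden, device="cuda")
gamma = torch.ones(hidden, device="cuda")
beta = torch.zeros(hidden, device="cuda")

# Returns a single tensor (the normalized output), not (out, mu, rsigma).
y = torch.ops.tex_ts.layernorm_fwd_inf_ts(x, gamma, beta, 1e-5)
```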
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <torch/script.h>
#include "extensions.h"
namespace {
transformer_engine::DType reverse_map_dtype(int64_t dtype) {
if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) {
return static_cast<transformer_engine::DType>(dtype);
} else {
NVTE_ERROR("Type not supported.");
}
}
} // namespace
at::Tensor cast_to_fp8_ts(const at::Tensor &input,
const at::Tensor &scale,
const at::Tensor &amax,
const at::Tensor &scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor output = cast_to_fp8(input,
scale[fp8_tensor],
amax[0][fp8_tensor],
scale_inv[fp8_tensor],
otype_arg);
return output.clone();
}
at::Tensor cast_from_fp8_ts(const at::Tensor &input,
const at::Tensor &scale_inv,
int64_t fp8_tensor,
int64_t itype,
int64_t otype) {
transformer_engine::DType itype_arg = reverse_map_dtype(itype);
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor output = cast_from_fp8(input,
scale_inv[fp8_tensor],
itype_arg,
otype_arg);
return output.clone();
}
at::Tensor fp8_gelu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor output = fp8_gelu(input,
scale[fp8_tensor],
amax[0][fp8_tensor],
scale_inv[fp8_tensor],
otype_arg);
return output.clone();
}
at::Tensor te_gemm_ts(at::Tensor A,
at::Tensor A_scale_inverse,
int64_t A_fp8_tensor,
int64_t A_type,
int64_t transa,
at::Tensor B,
at::Tensor B_scale_inverse,
int64_t B_fp8_tensor,
int64_t B_type,
int64_t transb,
at::Tensor D,
int64_t D_type,
at::Tensor bias,
at::Tensor pre_gelu_out,
int64_t grad,
at::Tensor workspace,
int64_t workspaceSize,
int64_t accumulate,
int64_t use_split_accumulator) {
// cast inputs to types accepted by te_gemm
transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
bool transa_arg = static_cast<bool>(transa);
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
at::Tensor A_scale_inverse_arg = A_scale_inverse.clone();
if (A_scale_inverse.numel())
A_scale_inverse_arg = A_scale_inverse[A_fp8_tensor];
at::Tensor B_scale_inverse_arg = B_scale_inverse.clone();
if (B_scale_inverse.numel())
B_scale_inverse_arg = B_scale_inverse[B_fp8_tensor];
te_gemm(A,
A_scale_inverse_arg,
A_type_arg,
transa_arg,
B,
B_scale_inverse_arg,
B_type_arg,
transb_arg,
D,
D_type_arg,
bias,
pre_gelu_out,
grad_arg,
workspace,
workspaceSize_arg,
accumulate_arg,
use_split_accumulator_arg);
return D;
}
at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
double eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps);
at::Tensor output = layernorm_fwd_fp8_inf(input,
weight,
bias,
eps_float,
scale,
amax,
scale_inv,
otype_arg);
return output.clone();
}
at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
double eps) {
float eps_float = static_cast<float>(eps);
at::Tensor output = layernorm_fwd_inf(input,
weight,
bias,
eps_float);
return output.clone();
}
TORCH_LIBRARY(tex_ts, m) {
m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
m.def("fp8_gelu_ts", &fp8_gelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
}
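A quick way to check that the TORCH_LIBRARY(tex_ts, ...) registration above took effect (a sketch under the assumption that ts_fp8_op.cpp is compiled into the transformer_engine_extensions module, as setup.py now arranges):

```python
import torch
import transformer_engine_extensions  # noqa: F401  (loading the extension runs the registration)

# Each registered schema is reachable under the torch.ops.tex_ts namespace.
print(torch.ops.tex_ts.cast_to_fp8_ts)
print(torch.ops.tex_ts.te_gemm_ts)
```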
@@ -5,9 +5,10 @@
 """Fused scaled masked softmax functions"""
 import os
 from typing import Callable, Tuple, Union

 import torch
 from torch import nn
+import torch._C._onnx as _C_onnx
+from torch.onnx import _type_utils

 import transformer_engine_extensions as tex
@@ -46,6 +47,36 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         return input_grads, None

+    @staticmethod
+    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
+        """ScaledUpperTriangMaskedSoftmax symbolic method"""
+        def triangular_mask():
+            dtype = _type_utils.JitScalarType.INT64
+            ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype)
+            k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
+            mask = g.op("Trilu", ones, k, upper_i=1)
+            mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
+            return mask
+
+        # Captures the logic of function scaled_upper_triang_masked_softmax_warp_forward
+        if inputs.type().scalarType() == "BFloat16":
+            inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+
+        mask = triangular_mask()
+        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
+        inv_mask = g.op("Sub", one, mask)
+
+        neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
+        softmax_mask = g.op("Mul", mask, neg_tenK)
+
+        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
+        scaled = g.op("Mul", inputs, scale_input)
+        masked_scaled = g.op("Mul", inv_mask, scaled)
+        masked = g.op("Add", masked_scaled, softmax_mask)
+        out = g.op("Softmax", masked)
+
+        if inputs.type().scalarType() == "BFloat16":
+            out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+        return out
+

 class ScaledMaskedSoftmax(torch.autograd.Function):
     """
@@ -78,6 +109,35 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         )
         return input_grads, None, None

+    @staticmethod
+    def symbolic(
+        g: torch.Graph,
+        inputs: torch._C.Value,
+        mask: torch._C.Value,
+        scale: float) -> torch._C.Value:
+        """ScaledMaskedSoftmax symbolic method"""
+        # Captures the logic of function scaled_masked_softmax_warp_forward.
+        # output = softmax(mask(input*scale))
+        # Computed as:
+        #   masked_scaled = (1 - mask)*(input*scale)
+        #   softmax_mask = mask * -10000
+        #   output = softmax(masked_scaled + softmax_mask)
+        if inputs.type().scalarType() == "BFloat16":
+            inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
+        scaled = g.op("Mul", inputs, scale_input)
+        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
+        inv_mask = g.op("Sub", one, mask)
+        # Note: type is hard coded because softmax uses FP16 or BF16
+        neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
+        softmax_mask = g.op("Mul", mask, neg_tenK)
+        masked_scaled = g.op("Mul", inv_mask, scaled)
+        masked = g.op("Add", masked_scaled, softmax_mask)
+        out = g.op("Softmax", masked)
+
+        if inputs.type().scalarType() == "BFloat16":
+            out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+        return out
+

 class ScaledSoftmax(torch.autograd.Function):
     """
@@ -107,6 +167,19 @@ class ScaledSoftmax(torch.autograd.Function):
         )
         return input_grads, None, None

+    @staticmethod
+    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
+        """ScaledSoftmax symbolic method"""
+        if inputs.type().scalarType() == "BFloat16":
+            inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+
+        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
+        scaled = g.op("Mul", inputs, scale_input)
+        out = g.op("Softmax", scaled)
+
+        if inputs.type().scalarType() == "BFloat16":
+            out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+        return out
+

 class FusedScaleMaskSoftmax(nn.Module):
     """
@@ -163,7 +236,7 @@ class FusedScaleMaskSoftmax(nn.Module):
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
             if 0 <= sk <= 4096:
-                batch_per_block = self.get_batch_per_block(sk)
+                batch_per_block = self.get_batch_per_block(int(sk))

                 if self.attn_mask_type == "causal":
                     if attn_batches % batch_per_block == 0:
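For reference, the graph built by ScaledMaskedSoftmax.symbolic corresponds to the following eager-mode computation (a clarifying sketch, not code from the diff):

```python
import torch

def scaled_masked_softmax_ref(inputs: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
    """Mirrors the ONNX graph: zero out masked positions of the scaled input,
    then shift them to -10000 so softmax assigns them ~0 probability."""
    mask = mask.to(inputs.dtype)
    masked_scaled = (1.0 - mask) * (inputs * scale)
    masked = masked_scaled + mask * -10000.0
    return torch.softmax(masked, dim=-1)
```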
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
ONNX symbolic functions for Transformer Engine
Warnings of the type pasted below are a known Pytorch issue
(https://github.com/pytorch/pytorch/issues/81693):
tests/test_onnx_export.py::test_export_cast_ops[112]
/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py:649:
UserWarning: The shape inference of trt::TRT_FP8DequantizeLinear type is missing,
so it may result in wrong shape inference for the exported graph.
Please consider adding it in symbolic function. (Triggered internally at
/opt/pytorch/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1880.)
_C._jit_pass_onnx_graph_shape_type_inference(
Scale tensors are treated as lists ("fs") instead of tensors ("v") because we need to access
specific entries using the index passed as `fp8_tensor`. If you fail to do this you will get
the following error when accessing a specific scale element (e.g. `scale_inv[fp8_tensor]`):
TypeError: 'torch._C.Value' object is not subscriptable
"""
import torch
from torch.onnx import symbolic_helper, register_custom_op_symbolic
import torch._C._onnx as _C_onnx
import transformer_engine_extensions as tex
# This file registers custom op symbolic ONNX functions and does not export any symbols.
__all__ = []
# Custom ops spec version
VER = 1
UNSPECIFIED_TYPE = -1
def make_op_name(op_name: str) -> str:
"""custom op name"""
return "trt::" + op_name
def quantize(g, inputs, scale_inv, fp8_tensor):
"""Helper Function for Quantization"""
output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
# Q inputs are currently constrained to FP32 due to a similar limitation in ORT
# custom ops, so cast the input if needed.
if inputs.type().scalarType() == "Half":
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
q_op = g.op(
make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType(
inputs.type().with_dtype(torch.uint8).with_sizes(output_shape))
return q_op
def dequantize(g, inputs, scale_inv, fp8_tensor, otype):
"""Helper Function for Dequantization"""
output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
out = g.op(make_op_name("TRT_FP8DequantizeLinear"), inputs, scale).setType(
inputs.type().with_dtype(torch.float32).with_sizes(output_shape))
# DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
# custom ops, so cast the output if needed.
if otype == int(tex.DType.kFloat16):
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return out
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for cast_to_fp8"""
# pylint: disable=unused-argument
return quantize(g, inputs, scale_inv, fp8_tensor)
@symbolic_helper.parse_args("v", "fs", "i", "i", "i")
def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
"""ONNX graph for cast_from_fp8"""
# pylint: disable=unused-argument
return dequantize(g, inputs, scale_inv, fp8_tensor, otype)
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_gelu"""
# pylint: disable=unused-argument
gelu = torch.onnx.symbolic_opset9.gelu(g, inputs, "tanh")
out = quantize(g, gelu, scale_inv, fp8_tensor)
return out
@symbolic_helper.parse_args("v", "fs", "i", "i", "i",
"v", "fs", "i", "i", "i",
"v", "i", "v", "v", "i",
"v", "i", "i", "i")
def onnx_te_gemm(
g,
weight,
weight_scale_inverse,
weight_fp8_tensor,
weight_type,
trans_weight,
inputs,
input_scale_inverse,
input_fp8_tensor,
input_type,
trans_input,
out,
out_type,
bias,
pre_gelu_out,
grad,
workspace,
workspaceSize,
accumulate,
use_split_accumulator):
"""ONNX graph for te_gemm"""
# pylint: disable=unused-argument
is_fp16 = bias.type().scalarType() == "Half"
if input_type == int(tex.DType.kFloat8E4M3):
inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, UNSPECIFIED_TYPE)
if weight_type == int(tex.DType.kFloat8E4M3):
weight = dequantize(g, weight, weight_scale_inverse, weight_fp8_tensor, UNSPECIFIED_TYPE)
output = g.op("Gemm", inputs, weight, transA_i=trans_input, transB_i=trans_weight)
empty_tensor_size = [0]
bias_empty = torch.onnx.symbolic_helper._get_tensor_sizes(bias) == empty_tensor_size
pre_gelu_out_empty = torch.onnx.symbolic_helper._get_tensor_sizes(pre_gelu_out) \
== empty_tensor_size
if not bias_empty:
if pre_gelu_out_empty:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
output = g.op('Add', output, bias)
else:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
output = g.op('Add', output, bias)
output = torch.onnx.symbolic_opset9.gelu(g, output)
else:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return output
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i")
def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for layernorm_fwd_fp8"""
# pylint: disable=unused-argument
ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
return fp8_ln
@symbolic_helper.parse_args("v", "v", "v", "f")
def onnx_layernorm_fwd(g, inputs, weight, bias, eps):
"""ONNX graph for layernorm_fwd"""
# pylint: disable=unused-argument
normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
if normalized_shape is None:
ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs)
assert ndim is not None
normalized_shape = list(range(0, ndim))
# Normalization axis = 0, so normalized_shape uses all dims except dim = 0
normalized_shape = normalized_shape[1:]
ln = torch.onnx.symbolic_opset9.layer_norm(
g,
inputs,
normalized_shape,
weight,
bias,
eps,
False # cudnn_enable (not relevant)
)
return ln
register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
register_custom_op_symbolic('tex_ts::fp8_gelu_ts', onnx_fp8_gelu, VER)
register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER)
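To run a graph exported with these symbolics in ONNX Runtime, the TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear custom ops must come from a custom-op library; the commit message notes that the FP8 ORT library lives under tests/. A hedged sketch, with the library and model file names purely illustrative:

```python
import onnxruntime as ort

so = ort.SessionOptions()
# Path/name of the FP8 Q/DQ custom-op library is hypothetical; see tests/ in this commit.
so.register_custom_ops_library("tests/libcustom_ort_fp8_qdq_ops.so")
sess = ort.InferenceSession("te_linear.onnx", so, providers=["CPUExecutionProvider"])
```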