Unverified Commit 6c9ce179 authored by asfiyab-nvidia, committed by GitHub

Add ONNX export support for TE modules (#41)



* Add ONNX export support for TE modules (#1)

* Add TorchScript Operators
* Add symbolic methods to ONNX exporter
* Add tests for the ONNX export
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fixes for pylint tests
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix pylint warning in softmax.py
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* move FP8 ORT lib inside tests/
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* enable cross attention tests
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* refactor code by @nzmora
* Increase layernorm FP16 threshold
* Normalize onnx file names: _ separates configs; - separates words in a single config
* Add get_attn_mask_str and fix mask string
* Add missing ONNX files
* Moved generated ONNX files to tests/gen_onnx_models/
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix merge conflict changes
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix Q/DQ scale input
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* enable FP16 config when bias is disabled
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix pylint check errors
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* updates
1. remove List import for pylint failure
2. address comments: remove state tensors from GPU
3. address comments: Update reverse_map_dtype function and add to namespace
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* minor fix: coding guidelines
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* changes:
1. skip FP8 tests on non-Hopper devices
2. minor fix for C++ lint check
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* fix onnxruntime version
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* minor fix: add space between code and comment
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* changes
1. update copyrights
2. update path to ORT .so
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: asfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com>
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Signed-off-by: asfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e2ad34e9
......@@ -8,6 +8,7 @@
*.nsys-rep
*.ncu-rep
*.sqlite
*.onnx
.eggs
build/
*.so
......
......@@ -6,5 +6,6 @@ set -e
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==6.2.5
pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/test_transformerengine.py
pytest -v -s $TE_PATH/tests/test_onnx_export.py
......@@ -14,7 +14,7 @@ from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
from distutils.version import LooseVersion
from distutils.file_util import copy_file
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
path = os.path.dirname(os.path.realpath(__file__))
......@@ -85,6 +85,7 @@ include_dirs = make_abs_path(include_dirs)
pytorch_sources = [
"transformer_engine/pytorch/csrc/extensions.cu",
"transformer_engine/pytorch/csrc/common.cu",
"transformer_engine/pytorch/csrc/ts_fp8_op.cpp",
]
pytorch_sources = make_abs_path(pytorch_sources)
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
This file contains tests for exporting TransformerEngine models to ONNX.
"""
import os
import pytest
import warnings
import numpy as np
import math
import onnxruntime as ort
import torch
from torch import nn as nn
from typing import Union, Tuple
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, fp8_gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.module import get_workspace
import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine.pytorch.softmax as softmax_defs
from transformer_engine.pytorch.utils import get_default_init_method
# Directory where generated ONNX test models are stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
ONNX_FILES_DIR = os.path.join(TESTS_DIR, "./gen_onnx_models")
# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
# ScaledUpperTriangMaskedSoftmax is exported via ONNX::Trilu which was introduced in opset 14.
TRILU_OPSET = 14
# Opset used in the ONNX files generated by the tests.
OPSET = 15
assert OPSET >= TRILU_OPSET
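# For reference, a minimal PyTorch sketch (an illustrative assumption, not the TE kernel
# or its ONNX symbolic) of the semantics the Trilu-based export reproduces: scale the
# attention scores, mask the strict upper triangle (future positions), then apply softmax.
def _scaled_upper_triang_masked_softmax_ref(scores: torch.Tensor, scale: float) -> torch.Tensor:
    seq = scores.size(-1)
    future = torch.triu(
        torch.ones(seq, seq, dtype=torch.bool, device=scores.device), diagonal=1)
    return torch.softmax((scores * scale).masked_fill(future, float("-inf")), dim=-1)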
skip_FP8 = pytest.mark.skipif(
torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
reason="Device compute capability 9.x required for FP8 execution.",
)
def create_fp8_recipe():
return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
def do_export(
model: torch.nn.Module,
inp: torch.Tensor,
fname: str,
use_fp8: bool=True,
opset: int=OPSET,
input_names: list=["input"],
output_names: list=["output"],
):
"""Export to ONNX"""
fp8_recipe = create_fp8_recipe()
with torch.inference_mode(), te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
warnings.filterwarnings(
action='ignore',
category=torch.jit.TracerWarning,
module=r'.*'
)
model.cuda().eval()
os.makedirs(ONNX_FILES_DIR, exist_ok=True)
fname = os.path.join(ONNX_FILES_DIR, fname)
torch.onnx.export(model,
inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,),
fname,
verbose=False,
opset_version=opset,
input_names=input_names,
output_names=output_names,
do_constant_folding=True,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
def to_numpy(tensor):
return tensor.cpu().numpy()
def set_layer_scale(module: torch.nn.Module, scale: float):
module.fp8_init()
module.fp8_meta["scaling_fwd"].scale = torch.ones(
2, dtype=torch.float32, device="cuda") / scale
module.fp8_meta["scaling_fwd"].scale_inv = torch.ones(
2, dtype=torch.float32, device="cuda") * scale
def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.Tensor], torch.Tensor], is_fp8: bool):
"""Transformer Engine forward propagation.
Return results after copying to the CPU and converting to numpy.
"""
fp8_recipe = create_fp8_recipe()
with torch.inference_mode(), te.fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
if not isinstance(te_outputs, tuple):
te_outputs = (te_outputs,)
te_outputs_np = [to_numpy(te_output) for te_output in te_outputs]
return te_outputs_np
def validate_result(
fname: str,
inps: Union[Tuple[torch.Tensor], torch.Tensor],
model: torch.nn.Module,
atol: float=1.e-8, # np.isclose default atol
rtol: float=1.e-5, # np.isclose default rtol
max_errors_printed: int=10,
is_fp8: bool=False,
):
"""Validate the outputs of an ONNX model vs. ONNX Runtime."""
def create_ort_session(fname: str, is_fp8: bool):
def load_custom_ops(session_opts: ort.SessionOptions):
"""For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension."""
if not os.path.exists(ORT_CUSTOM_OPS_LIB):
raise FileNotFoundError(f"Unable to find {ORT_CUSTOM_OPS_LIB}")
session_opts.register_custom_ops_library(ORT_CUSTOM_OPS_LIB)
print("registered custom FP8 Q/DQ ops!")
"""Create an ONNX Runtime session for validation."""
if is_fp8:
sess_options = ort.SessionOptions()
load_custom_ops(sess_options)
# Successful model loading indicates that the custom op nodes could be resolved.
s = ort.InferenceSession(fname, sess_options=sess_options)
else:
s = ort.InferenceSession(fname)
return s
def create_ort_input_dict(session, inps):
inp_dict = {}
if isinstance(inps, tuple) or isinstance(inps, list):
nonetype_inputs = 0
for idx, inp in enumerate(inps):
if inp is None:
nonetype_inputs += 1
continue
inp_dict[session.get_inputs()[idx - nonetype_inputs].name] = to_numpy(inp)
else:
inp_dict[session.get_inputs()[0].name] = to_numpy(inps)
return inp_dict
# Run ORT session and TE model.
fname = os.path.join(ONNX_FILES_DIR, fname)
ort_s = create_ort_session(fname, is_fp8)
onnx_outputs = ort_s.run(None, input_feed=create_ort_input_dict(ort_s, inps))
te_outputs = te_infer(model, inps, is_fp8)
# Compare ORT and TE outputs.
assert len(onnx_outputs) == len(te_outputs)
for onnx_output, te_output in zip(onnx_outputs, te_outputs):
# Compare ORT and PyTorch outputs.
# np.isclose: abs(a - b) <= (atol + rtol * abs(b))
ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)
mismatches = ac.nonzero()
mismatched_ids = [loc for loc in zip(*mismatches)]
if mismatched_ids:
# Log some information in case of error.
print("*" * 100)
print(onnx_output.shape)
nb_vals = min(len(mismatched_ids), max_errors_printed)
print(f"Detected {len(mismatched_ids)} diverging values.\nShowing first {nb_vals} errors (ONNX -- TE):")
abs_err = abs(onnx_output - te_output)
for loc in mismatched_ids[:nb_vals]:
ref = te_output[loc]
print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}")
raise ValueError(f"Output validation of {fname} failed with {len(mismatched_ids)} errors")
def create_meta(scale_factor: float, size: int=1):
meta = tex.FP8TensorMeta()
meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
return meta
def dtype2str(dtype: torch.dtype):
return {
torch.float32: "_fp32",
torch.float16: "_fp16",
torch.bfloat16: "_bf16",
}[dtype]
def as_te_type(dtype: torch.dtype):
return {
torch.float32: tex.DType.kFloat32,
torch.float16: tex.DType.kFloat16,
torch.bfloat16: tex.DType.kBFloat16,
}[dtype]
def get_attn_mask_str(use_mask, attn_mask_type):
# See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
if attn_mask_type is None:
return "_mask" if use_mask else "_no-mask"
attn_mask_str = "_padding-no-mask"
attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
attn_mask_str = "_padding-mask" if use_mask and attn_mask_type == "padding" else attn_mask_str
return attn_mask_str
@skip_FP8
@pytest.mark.parametrize("scale_factor, atol", [
(1, 1e-7),
(224, 1e-7)
])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
def test_export_cast_ops(scale_factor: float, atol: float, precision: torch.dtype):
class TestFP8_QDQ(nn.Module):
def __init__(self):
super().__init__()
self.fp8_tensor = 0
self.meta = create_meta(scale_factor)
self.highprec_type = as_te_type(precision)
self.fp8_type = tex.DType.kFloat8E4M3
def forward(self, inp):
ret = cast_to_fp8(
inp,
self.meta,
self.fp8_tensor,
self.fp8_type)
ret = cast_from_fp8(
ret,
self.meta,
self.fp8_tensor,
self.fp8_type,
self.highprec_type)
return ret
# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
high_prec_str = dtype2str(precision)
fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_QDQ()
do_export(model, inp, fname)
validate_result(fname, inp, model, atol=atol, is_fp8=True)
@skip_FP8
@pytest.mark.parametrize("scale_factor", [448])
@pytest.mark.parametrize(
"precision, atol", [
[torch.float32, 1e-7],
[torch.float16, 2e-3]
])
def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
class TestFP8_Gelu(nn.Module):
def __init__(self):
super().__init__()
self.fp8_tensor = 0
self.meta = create_meta(scale_factor)
self.highprec_type = as_te_type(precision)
self.fp8_type = tex.DType.kFloat8E4M3
def forward(self, inp):
ret = fp8_gelu(
inp,
self.meta,
self.fp8_tensor,
self.fp8_type)
ret = cast_from_fp8(
ret,
self.meta,
self.fp8_tensor,
self.fp8_type,
self.highprec_type)
return ret
# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
high_prec_str = dtype2str(precision)
fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_Gelu()
do_export(model, inp, fname)
validate_result(fname, inp, model, rtol=1e-1, atol=atol, is_fp8=True)
@pytest.mark.parametrize("scale_factors",
[(224, 224,),
])
@pytest.mark.parametrize(
"precision, use_fp8, use_bias, use_gelu", [
(torch.float32, False, False, False),
(torch.float16, False, False, False),
(torch.float32, False, True, False),
(torch.float16, False, True, False),
(torch.float32, False, True, True),
(torch.float16, False, True, True),
# For FP8 GEMM GeLU is not used.
(torch.float32, True, False, False),
(torch.float16, True, False, False),
# When enabling bias we must use float16 or bfloat16 (because of kernel limitations)
(torch.float16, True, True, False),
(torch.bfloat16, True, True, False),
])
def test_export_gemm(
precision, # Precision of inputs, weights, output and bias
use_fp8,
use_bias,
use_gelu,
scale_factors
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
class TestFP8_GEMM(nn.Module):
def __init__(self, precision, use_bias, gelu, scale_factors):
super().__init__()
self.use_bias = use_bias
self.gelu = gelu
self.precision = precision
self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
nb_inp_scales, nb_weight_scales = 1, out_features
act_scale_factor, weight_scale_factor = scale_factors
self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)
bias_size = nb_weight_scales
self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")
self.inp_type = tex.DType.kFloat8E4M3
self.weights_type = tex.DType.kFloat8E4M3
self.outp_type = precision
def forward(self, inp, weight):
inp_fp8 = cast_to_fp8(
inp,
self.meta_inp,
self.fp8_tensor_inp,
self.inp_type)
weight_fp8 = cast_to_fp8(
weight,
self.meta_weight,
self.fp8_tensor_weight,
self.weights_type)
ret = fp8_gemm(
weight_fp8,
self.meta_weight.scale_inv,
self.fp8_tensor_weight,
self.inp_type,
inp_fp8,
self.meta_inp.scale_inv,
self.fp8_tensor_inp,
self.weights_type,
self.outp_type,
get_workspace(),
bias=self.bias,
use_bias=self.use_bias,
fp32_output=(self.precision==torch.float32),
use_split_accumulator=False)
return ret
class Test_GEMM(nn.Module):
def __init__(self, precision, use_bias=False, gelu=False):
super().__init__()
self.use_bias = use_bias
self.gelu = gelu
self.precision = precision
bias_size = out_features
self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")
def forward(self, inp, weight):
outp_type = self.precision
# note: due to logic in lines 104:116 and L129 in cpp_extensions.py
# it appears either bias OR gelu can be activated, not both
ret, _, _ = gemm(
weight,
inp,
outp_type,
get_workspace(),
# test bias
bias=self.bias,
use_bias=self.use_bias,
# test gelu
gelu=self.gelu,
gelu_input=self.gelu_input,
grad=False # only True for backward pass
)
return ret
# If gelu is applied then bias must be added, as defined by TE kernel.
if use_gelu: assert use_bias
# Set dimensions (these are arbitrary).
out_features = 128
hidden_size = 256
in_features = 64
inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda")
weight = torch.randn(out_features, in_features, dtype=precision, device="cuda")
fp8_str = "_fp8" if use_fp8 else ""
bias_str = "_bias" if use_bias else ""
gelu_str = "_gelu" if use_gelu else ""
high_prec_str = dtype2str(precision)
fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx"
if use_fp8:
model = TestFP8_GEMM(precision, use_bias, use_gelu, scale_factors)
do_export(model, (inp, weight), fname, use_fp8)
if precision not in (torch.bfloat16, torch.float16):
validate_result(fname, (inp, weight), model, rtol=1e-2, atol=1e-2, is_fp8=True)
else:
model = Test_GEMM(precision, use_bias, use_gelu)
do_export(model, (inp, weight), fname, use_fp8)
validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2)
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
def test_export_layernorm(
use_fp8: bool,
scale_factor: float,
precision: torch.dtype
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
# Set dimensions (these are arbitrary).
inp_shape = [64, 32]
class Test_Layernorm(nn.Module):
def __init__(self) -> None:
super().__init__()
normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
self.eps = 1e-6 # An arbitrary small value
def forward(self, inp):
ret = texcpp.layernorm_fwd_inf(
inp,
self.weight,
self.bias,
self.eps)
return ret
class TestFP8_Layernorm(nn.Module):
def __init__(self) -> None:
super().__init__()
normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
self.meta = create_meta(scale_factor)
self.fp8_type = tex.DType.kFloat8E4M3
def forward(self, inp):
ret = texcpp.layernorm_fwd_fp8_inf(
inp,
self.weight,
self.bias,
self.eps,
self.meta,
self.fp8_tensor,
self.fp8_type)
ret = cast_from_fp8(
ret,
self.meta,
self.fp8_tensor,
self.fp8_type,
tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16)
return ret
inp = torch.randn(*inp_shape, device="cuda", dtype=precision)
model = TestFP8_Layernorm() if use_fp8 else Test_Layernorm()
high_prec_str = dtype2str(precision)
fp8_str = f"_fp8-{scale_factor}" if use_fp8 else ""
fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx"
do_export(model, inp, fname, use_fp8=use_fp8)
if precision not in (torch.bfloat16, ):
# TODO: FP32 has a small threshold (1e-5)
validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8)
@skip_FP8
@pytest.mark.parametrize("softmax_def", [
softmax_defs.ScaledUpperTriangMaskedSoftmax,
softmax_defs.ScaledMaskedSoftmax,
softmax_defs.ScaledSoftmax,
])
# Softmax kernel only supports FP16 or BF16!
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
def test_export_softmax(softmax_def, precision):
class Test_Softmax(nn.Module):
def __init__(self, softmax_function, mask_inp=False):
super().__init__()
self.softmax_fn = softmax_function
self.mask_inp = mask_inp
def forward(self, inp, mask):
scale_factor = 8 # arbitrary value
if self.mask_inp:
ret = self.softmax_fn.apply(inp, mask, scale_factor)
else:
ret = self.softmax_fn.apply(inp, scale_factor)
return ret
# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
mask = None
input_names = ["input"]
inp_shape = [hidden_size, in_features, in_features, in_features]
if softmax_def == softmax_defs.ScaledUpperTriangMaskedSoftmax:
inp_shape = [hidden_size, in_features, in_features]
kernel_str = "ScaledUpperTriangMaskedSoftmax"
model = Test_Softmax(softmax_def)
elif softmax_def == softmax_defs.ScaledMaskedSoftmax:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(hidden_size, 1, in_features, in_features, device="cuda", dtype=precision)
mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
input_names.append("mask")
kernel_str = "ScaledMaskedSoftmax"
model = Test_Softmax(softmax_def, mask_inp=True)
elif softmax_def == softmax_defs.ScaledSoftmax:
kernel_str = "ScaledSoftmax"
model = Test_Softmax(softmax_def)
input_tensor = torch.randn(*inp_shape, device="cuda")
input_tensor = input_tensor.to(torch.bfloat16) if precision == torch.bfloat16 else input_tensor.half()
high_prec_str = dtype2str(precision)
fname = f"{kernel_str}{high_prec_str}.onnx"
inp = (input_tensor, mask)
do_export(model, inp, fname, input_names=input_names)
if precision != torch.bfloat16:
validate_result(fname, inp, model, atol=1e-3)
@pytest.mark.parametrize("scale_factor", [1])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize(
"precision, use_bias",[
(torch.float32, False),
(torch.float32, True),
(torch.float16, False),
(torch.float16, True),
# Todo: cannot configure BF16 when bias is disabled (ORT issue?)
(torch.bfloat16, False),
# Todo: cannot configure BF16 when bias is enabled (ORT issue?)
# (torch.bfloat16, True),
])
def test_export_linear(
scale_factor: float,
use_fp8: bool,
use_bias: bool,
return_bias: bool,
precision: torch.dtype
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
# Set dimensions (these are arbitrary).
in_features = 64
out_features = 256
hidden_size = 256
class Test_Linear(nn.Module):
def __init__(self,
in_features,
out_features,
use_bias,
return_bias,
precision
):
super().__init__()
self.linear = te.Linear(
in_features,
out_features,
bias=use_bias,
return_bias=return_bias,
params_dtype=precision
)
def forward(self, inp):
ret = self.linear(inp)
return ret
inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
fp8_str = "_fp8" if use_fp8 else ""
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
with te.fp8_autocast(enabled=use_fp8):
model = Test_Linear(
in_features,
out_features,
use_bias,
return_bias,
precision
).to(device='cuda')
if use_fp8:
set_layer_scale(model.linear, scale_factor)
do_export(model, inp, fname, use_fp8)
if precision in (torch.bfloat16, ):
return
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3)
else:
validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
"precision, use_bias",[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
])
def test_export_layernorm_linear(
scale_factor: float,
use_fp8: bool,
use_bias: bool,
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
# Set dimensions (these are arbitrary).
in_features = 64
out_features = 256
hidden_size = 256
inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)
fp8_str = "_fp8" if use_fp8 else ""
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"
with te.fp8_autocast(enabled=use_fp8):
model = te.LayerNormLinear(
hidden_size,
3 * hidden_size,
bias=use_bias,
return_bias=return_bias,
return_layernorm_output=return_layernorm_output,
params_dtype=precision,
).to(device='cuda')
if use_fp8:
set_layer_scale(model, scale_factor)
do_export(model, inp, fname, use_fp8)
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3)
elif precision not in (torch.bfloat16,):
validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
"precision, use_bias",[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
])
def test_export_layernorm_mlp(
scale_factor: float,
use_fp8: bool,
use_bias: bool,
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
# Set dimensions (these are arbitrary).
in_features = 64
out_features = 256
hidden_size = 256
ffn_hidden_size = 256
inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)
fp8_str = "_fp8" if use_fp8 else ""
bias_str = "_bias" if use_bias else ""
high_prec_str = dtype2str(precision)
fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}.onnx"
with te.fp8_autocast(enabled=use_fp8):
model = te.LayerNormMLP(
hidden_size,
ffn_hidden_size,
bias=use_bias,
return_bias=return_bias,
return_layernorm_output=return_layernorm_output,
params_dtype=precision,
).to(device='cuda')
if use_fp8:
set_layer_scale(model, scale_factor)
do_export(model, inp, fname, use_fp8)
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3)
else:
validate_result(fname, inp, model, atol=2e-2, is_fp8=use_fp8)
@skip_FP8
@pytest.mark.parametrize(
"precision, use_mask, attn_mask_type", [
(torch.float32, False, None), # calls forward_torch_softmax
(torch.float32, True, None), # calls forward_torch_softmax
(torch.float16, False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
(torch.float16, True, "padding"), # calls ScaledMaskedSoftmax
(torch.float16, False, "padding"), # calls ScaledSoftmax
])
@pytest.mark.parametrize("attention_softmax_in_fp32",
[True, False])
@pytest.mark.parametrize("apply_query_key_layer_scaling",
[True, False])
def test_export_core_attention(
precision: torch.dtype,
use_mask: bool,
attn_mask_type: str,
attention_softmax_in_fp32: bool,
apply_query_key_layer_scaling: bool,
):
# Set dimensions (these are arbitrary).
kv_channels = 64
num_attention_heads = 1
qkv_size = (2048, 4, num_attention_heads, kv_channels)
query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
input_names = ["query", "key", "value"]
attention_mask = None
if use_mask:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
input_names.append("attention_mask")
inp = (query_layer, key_layer, value_layer, attention_mask)
sm_prec_str = "_sm-fp32" if attention_softmax_in_fp32 else "_sm-fp16"
qk_scaling_str = "_qk-scaling" if apply_query_key_layer_scaling else ""
mask_str = get_attn_mask_str(use_mask, attn_mask_type)
high_prec_str = dtype2str(precision)
fname = f"te.core_attention{mask_str}{qk_scaling_str}{sm_prec_str}{high_prec_str}.onnx"
if attn_mask_type is None:
attn_mask_type = 'causal'
model = te.transformer.CoreAttention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
attention_dropout=0.5,
attn_mask_type=attn_mask_type,
attention_softmax_in_fp32=attention_softmax_in_fp32,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
).to(device='cuda')
do_export(model,
inp,
fname,
input_names=input_names,
use_fp8=True)
validate_result(fname, inp, model, atol=1e-2)
test_configs_multihead_attention = [
#"use_mask, attn_mask_type"
(False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
(True, "padding"), # calls ScaledMaskedSoftmax
(False, "padding"), # calls ScaledSoftmax
]
test_configs_attention_type = [
#"input_layernorm, attention_type, fuse_qkv_params"
(True, "self", True),
(False, "self", True),
(True, "self", False),
(False, "self", False),
# disabled because query_bias (required for cross attention) is only defined when fuse_qkv_params is False
# (True, "cross", True),
# (False, "cross", True),
(True, "cross", False),
# disabled because TypeError: cannot assign 'transformer_engine.pytorch.module.Linear'
# as parameter 'query' (torch.nn.Parameter or None expected)
# (False, "cross", False),
]
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type)
def test_export_multihead_attention(
use_fp8: bool,
use_mask: bool,
attn_mask_type: str,
precision: torch.dtype,
return_layernorm_output: bool,
input_layernorm: bool,
attention_type: str,
fuse_qkv_params: bool
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
hidden_size = 256
sequence_length = 128
batch_size = 4
num_attention_heads = 32
kv_channels = 8
attention_dropout = 0.1
layernorm_epsilon = 1e-5
init_method = output_layer_init_method = get_default_init_method()
attention_args = (
hidden_size,
num_attention_heads,
kv_channels,
attention_dropout,
layernorm_epsilon,
init_method,
output_layer_init_method,
)
hidden_states = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
encoder_output = None
if attention_type == "cross":
encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
inp = (hidden_states, attention_mask, encoder_output)
input_names = ["hidden_states", "attention_mask", "encoder_output"]
fp8_str = "_fp8" if use_fp8 else ""
dtype_str = dtype2str(precision)
attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention"
fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else ""
attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
input_ln_str = "_input-ln" if input_layernorm else ""
fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"
model = te.transformer.MultiHeadAttention(
*attention_args,
attn_mask_type=attn_mask_type,
params_dtype=precision,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
attention_type=attention_type,
fuse_qkv_params=fuse_qkv_params,
).to(device='cuda')
do_export(model, inp, fname, use_fp8, input_names=input_names)
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3)
elif precision != torch.float16:
validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8)
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [
#True, # TODO: handle this
False
])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("apply_query_key_layer_scaling", [True, False])
def test_export_transformer_layer(
use_fp8: bool,
use_mask: bool,
attn_mask_type: str,
output_layernorm: bool,
precision: torch.dtype,
fuse_qkv_params: bool,
apply_query_key_layer_scaling: bool
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
# Layer configuration
hidden_size = 64
sequence_length = 128
batch_size = 1
ffn_hidden_size = 256
num_attention_heads = 4
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
input_names = ["input"]
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
input_names.append("attention_mask")
inp = (input_tensor, attention_mask)
fp8_str = "_fp8" if use_fp8 else ""
fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
qk_scaling_str = "_qk-scaling" if apply_query_key_layer_scaling else ""
high_prec_str = dtype2str(precision)
attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{qk_scaling_str}{high_prec_str}.onnx"
model = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_attention_heads,
self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
apply_query_key_layer_scaling=apply_query_key_layer_scaling).to(device='cuda')
do_export(model, inp, fname, use_fp8)
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3)
elif precision != torch.float16:
validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8)
......@@ -10,3 +10,12 @@ from .module import LayerNorm
from .transformer import TransformerLayer
from .fp8 import fp8_autocast
from .distributed import checkpoint
# Register custom op symbolic ONNX functions
from .te_onnx_extensions import (
onnx_cast_to_fp8,
onnx_cast_from_fp8,
onnx_fp8_gelu,
onnx_te_gemm,
onnx_layernorm_fwd_fp8,
onnx_layernorm_fwd,
)
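The imports above register the ONNX symbolic handlers for the new TorchScript ops. As a rough sketch of the mechanism only (the emitted graph op and its name below are illustrative assumptions, not the contents of te_onnx_extensions.py), a symbolic function is attached to a TorchScript custom op so that torch.onnx.export can lower it:
import torch
from torch.onnx import register_custom_op_symbolic

def onnx_cast_to_fp8_sketch(g, inp, scale, amax, scale_inv, fp8_tensor, otype):
    # Illustrative only: emit a placeholder quantize node in place of tex_ts::cast_to_fp8_ts.
    # The real symbolic (onnx_cast_to_fp8) lives in te_onnx_extensions.py.
    return g.op("te_custom::Quantize", inp, scale)

# Opset 15 matches the OPSET constant used in tests/test_onnx_export.py.
register_custom_op_symbolic("tex_ts::cast_to_fp8_ts", onnx_cast_to_fp8_sketch, 15)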
......@@ -12,9 +12,11 @@ from .constants import TE_DType
def fp8_gemm(
A: torch.Tensor,
A_scale_inv: torch.Tensor,
A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
A_dtype: tex.DType,
B: torch.Tensor,
B_scale_inv: torch.Tensor,
B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
B_dtype: tex.DType,
out_dtype: torch.dtype,
workspace: torch.Tensor,
......@@ -41,19 +43,21 @@ def fp8_gemm(
out_dtype = tex.DType.kFloat32 if fp32_output else TE_DType[out_dtype]
tex.te_gemm(
_ = torch.ops.tex_ts.te_gemm_ts(
A,
A_scale_inv,
A_fp8_tensor,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor,
B_dtype,
False, # transb
out,
out_dtype,
bias if use_bias else empty_tensor,
empty_tensor,
empty_tensor, # this is pre_gelu_out
False, # grad
workspace,
workspace.shape[0],
......@@ -87,6 +91,7 @@ def gemm(
transa = layout[0] == "T"
transb = layout[1] == "T"
empty_tensor = torch.Tensor()
fp8_index = -1 # dummy index
input_dtype = TE_DType[dtype]
output_dtype = tex.DType.kFloat32 if fp32_output else input_dtype
......@@ -115,13 +120,15 @@ def gemm(
bias = bias if use_bias else empty_tensor
tex.te_gemm(
_ = torch.ops.tex_ts.te_gemm_ts(
A,
empty_tensor,
fp8_index,
input_dtype,
transa,
B,
empty_tensor,
fp8_index,
input_dtype,
transb,
out,
......@@ -214,11 +221,12 @@ def fp8_gelu(
otype: tex.DType,
) -> torch.Tensor:
"""GeLU with FP8 output"""
return tex.fp8_gelu(
return torch.ops.tex_ts.fp8_gelu_ts(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
)
......@@ -245,6 +253,48 @@ def layernorm_fwd_fp8(
)
def layernorm_fwd_fp8_inf(
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> torch.Tensor:
"""LayerNorm with FP8 output.
This version of layernorm_fwd_fp8 is specialized for inference, and returns
only the normalized output.
"""
ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype)
return ret
def layernorm_fwd_inf(
inp: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
) -> torch.Tensor:
"""LayerNorm with FP8 output"""
return torch.ops.tex_ts.layernorm_fwd_inf_ts(
inp,
weight,
bias,
eps,
)
def cast_to_fp8(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
......@@ -252,11 +302,12 @@ def cast_to_fp8(
otype: tex.DType,
) -> torch.Tensor:
"""Cast input to FP8"""
return tex.cast_to_fp8(
return torch.ops.tex_ts.cast_to_fp8_ts(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
)
......@@ -269,9 +320,10 @@ def cast_from_fp8(
otype: tex.DType,
) -> torch.Tensor:
"""Cast input from FP8"""
return tex.cast_from_fp8(
return torch.ops.tex_ts.cast_from_fp8_ts(
inp,
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.scale_inv,
fp8_tensor,
itype,
otype,
)
......@@ -94,6 +94,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
return transformer_engine::DType::kFloat32;
case at::kBFloat16:
return transformer_engine::DType::kBFloat16;
case at::kBool:
return transformer_engine::DType::kByte;
default:
NVTE_ERROR("Invalid type");
}
......
......@@ -397,6 +397,23 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
}
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
// This is a specialized version of layernorm_fwd_fp8, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd_fp8(
input, weight, bias, eps, scale, amax, scale_inv, otype);
return out[0];
}
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
......@@ -428,6 +445,16 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
return {ln_out, mu, rsigma};
}
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps
) {
// This is a specialized version of layernorm_fwd, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps);
return out[0];
}
at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale,
......
......@@ -95,6 +95,15 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
transformer_engine::DType otype
);
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const at::Tensor &weight,
......@@ -102,6 +111,11 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
float eps
);
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps
);
at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale,
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <torch/script.h>
#include "extensions.h"
namespace {
transformer_engine::DType reverse_map_dtype(int64_t dtype) {
if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) {
return static_cast<transformer_engine::DType>(dtype);
} else {
NVTE_ERROR("Type not supported.");
}
}
} // namespace
at::Tensor cast_to_fp8_ts(const at::Tensor &input,
const at::Tensor &scale,
const at::Tensor &amax,
const at::Tensor &scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor output = cast_to_fp8(input,
scale[fp8_tensor],
amax[0][fp8_tensor],
scale_inv[fp8_tensor],
otype_arg);
return output.clone();
}
at::Tensor cast_from_fp8_ts(const at::Tensor &input,
const at::Tensor &scale_inv,
int64_t fp8_tensor,
int64_t itype,
int64_t otype) {
transformer_engine::DType itype_arg = reverse_map_dtype(itype);
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor output = cast_from_fp8(input,
scale_inv[fp8_tensor],
itype_arg,
otype_arg);
return output.clone();
}
at::Tensor fp8_gelu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor output = fp8_gelu(input,
scale[fp8_tensor],
amax[0][fp8_tensor],
scale_inv[fp8_tensor],
otype_arg);
return output.clone();
}
at::Tensor te_gemm_ts(at::Tensor A,
at::Tensor A_scale_inverse,
int64_t A_fp8_tensor,
int64_t A_type,
int64_t transa,
at::Tensor B,
at::Tensor B_scale_inverse,
int64_t B_fp8_tensor,
int64_t B_type,
int64_t transb,
at::Tensor D,
int64_t D_type,
at::Tensor bias,
at::Tensor pre_gelu_out,
int64_t grad,
at::Tensor workspace,
int64_t workspaceSize,
int64_t accumulate,
int64_t use_split_accumulator) {
// cast inputs to types accepted by te_gemm
transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
bool transa_arg = static_cast<bool>(transa);
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
at::Tensor A_scale_inverse_arg = A_scale_inverse.clone();
if (A_scale_inverse.numel())
A_scale_inverse_arg = A_scale_inverse[A_fp8_tensor];
at::Tensor B_scale_inverse_arg = B_scale_inverse.clone();
if (B_scale_inverse.numel())
B_scale_inverse_arg = B_scale_inverse[B_fp8_tensor];
te_gemm(A,
A_scale_inverse_arg,
A_type_arg,
transa_arg,
B,
B_scale_inverse_arg,
B_type_arg,
transb_arg,
D,
D_type_arg,
bias,
pre_gelu_out,
grad_arg,
workspace,
workspaceSize_arg,
accumulate_arg,
use_split_accumulator_arg);
return D;
}
at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
double eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps);
at::Tensor output = layernorm_fwd_fp8_inf(input,
weight,
bias,
eps_float,
scale,
amax,
scale_inv,
otype_arg);
return output.clone();
}
at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
double eps) {
float eps_float = static_cast<float>(eps);
at::Tensor output = layernorm_fwd_inf(input,
weight,
bias,
eps_float);
return output.clone();
}
TORCH_LIBRARY(tex_ts, m) {
m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
m.def("fp8_gelu_ts", &fp8_gelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
}
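Once the extension built from ts_fp8_op.cpp is imported (it is compiled into transformer_engine_extensions per the setup.py change above), the TORCH_LIBRARY registration makes these operators callable from Python as torch.ops.tex_ts.<name>, which is how the updated cpp_extensions.py below invokes them. A minimal, hypothetical usage sketch (shapes and eps are arbitrary):
import torch
import transformer_engine_extensions  # noqa: F401 -- importing it registers the tex_ts ops

inp = torch.randn(64, 32, device="cuda")
weight = torch.randn(32, device="cuda")
bias = torch.zeros(32, device="cuda")
out = torch.ops.tex_ts.layernorm_fwd_inf_ts(inp, weight, bias, 1e-6)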
......@@ -4,12 +4,14 @@
"""Top level Transformer Engine PyTorch modules"""
import os
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Union, Optional, Callable, Tuple, Dict, List, Any, Mapping
from typing import Union, Optional, Callable, Tuple, Dict, Any, Mapping
from functools import partial
from contextlib import contextmanager
import numpy as np
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
......@@ -70,6 +72,8 @@ from .cpp_extensions import (
fp8_gelu,
fp8_cast_transpose_bgrad_dgelu_fused,
layernorm_fwd_fp8,
layernorm_fwd_fp8_inf,
layernorm_fwd_inf,
cast_to_fp8,
cast_from_fp8,
)
......@@ -192,8 +196,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.set_meta_tensor(True)
self.set_meta_tensor(False)
def get_extra_state(self) -> Union[List[Any], None]:
def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing."""
state = None
if self.fp8:
state = {}
state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
......@@ -210,10 +215,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
extra[k] = v
state["extra_fp8_variables"] = extra
return state
return None
state_serialized = pickle.dumps(state)
state_tensor = torch.tensor(np.frombuffer(state_serialized, dtype=np.uint8))
def set_extra_state(self, state: Union[List[Any], None]) -> None:
return state_tensor
def set_extra_state(self, state: torch.Tensor) -> None:
"""Load previous state."""
if state is None:
return
......@@ -252,6 +259,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["autocast_id_bwd"] = state[9]
return
if isinstance(state, torch.Tensor):
state = pickle.loads(state.detach().numpy().tobytes())
if state is None:
return
# Restore global FP8 buffer states.
set_global_fp8_buffer(state["global_fp8_buffer"])
set_global_fp8_recompute_buffer(state["global_fp8_recompute_buffer"])
......@@ -541,13 +553,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
fp8_dtype_backward,
)
else:
grad_output_t = None
grad_output_c = cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
grad_output_t = None
grad_bias = None
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
......@@ -557,6 +569,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Needs override."""
class _LayerNormLinear(torch.autograd.Function):
"""LayerNormLinear semi-top level module
Calls custom cuda extensions.
......@@ -584,6 +597,7 @@ class _LayerNormLinear(torch.autograd.Function):
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
return_layernorm_output: bool,
is_training: bool
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -604,6 +618,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output:
if is_training:
ln_out, mu, rsigma = layernorm_fwd_fp8(
inputmat,
ln_weight,
......@@ -614,9 +629,26 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward,
)
else:
mu = rsigma = None
ln_out = layernorm_fwd_fp8_inf(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
if is_training:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps
)
else:
ln_out_return, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps
), None, None
ln_out = cast_to_fp8(
ln_out_return,
fp8_meta["scaling_fwd"],
......@@ -624,7 +656,12 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward,
)
else:
if is_training:
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps
), None, None
ln_out_return = ln_out
# Column Parallel Linear
......@@ -642,6 +679,7 @@ class _LayerNormLinear(torch.autograd.Function):
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights:
if is_training:
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
......@@ -650,13 +688,22 @@ class _LayerNormLinear(torch.autograd.Function):
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward)
out = fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
ln_out_total,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
activation_dtype,
get_workspace(),
......@@ -678,6 +725,7 @@ class _LayerNormLinear(torch.autograd.Function):
use_bias=use_bias,
)
if is_training:
ctx.save_for_backward(
inputmat,
ln_weight,
......@@ -715,6 +763,7 @@ class _LayerNormLinear(torch.autograd.Function):
return out, ln_out_return.view_as(inp)
return out
@staticmethod
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
......@@ -768,10 +817,12 @@ class _LayerNormLinear(torch.autograd.Function):
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
......@@ -804,12 +855,12 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad = fp8_gemm(
ln_out_total_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
......@@ -890,6 +941,7 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -1051,7 +1103,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
if self.parallel_mode == "column":
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
else:
self.register_buffer("bias", torch.Tensor(), persistent=False)
self.register_buffer("bias", torch.Tensor().type(params_dtype), persistent=False)
with torch.no_grad():
self.bias.zero_()
......@@ -1110,7 +1162,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = bias if bias is not None else self.bias
out = _LayerNormLinear.apply(
if self.training:
fwd_fn = _LayerNormLinear.apply
args = []
else:
fwd_fn = _LayerNormLinear.forward
args = [None]
args += (
inp,
self.layer_norm_weight,
self.layer_norm_bias,
......@@ -1130,7 +1188,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.training,
)
out = fwd_fn(*args)
if self.return_layernorm_output:
out, ln_out = out
......@@ -1146,7 +1206,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
return out, ln_out
return out
class _Linear(torch.autograd.Function):
"""Linear semi-top level module
Calls custom cuda extensions.
......@@ -1170,6 +1229,7 @@ class _Linear(torch.autograd.Function):
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_training: bool,
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
......@@ -1186,6 +1246,7 @@ class _Linear(torch.autograd.Function):
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not fp8_meta["recipe"].override_linear_precision.wgrad:
if is_training:
inputmat, inputmat_t = fp8_cast_transpose_fused(
inputmat,
fp8_meta["scaling_fwd"],
......@@ -1199,6 +1260,13 @@ class _Linear(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
inputmat, inputmat_t = cast_to_fp8(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
), None
# Column Parallel Linear
if parallel_mode == "column" and sequence_parallel:
......@@ -1215,6 +1283,7 @@ class _Linear(torch.autograd.Function):
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights:
if is_training:
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
......@@ -1223,13 +1292,23 @@ class _Linear(torch.autograd.Function):
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
out = fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
inputmat,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
activation_dtype,
get_workspace(),
......@@ -1251,6 +1330,7 @@ class _Linear(torch.autograd.Function):
use_bias=use_bias,
)
if is_training:
ctx.save_for_backward(
inputmat_no_fp8
if not fp8 or fp8_meta["recipe"].override_linear_precision.wgrad
......@@ -1283,6 +1363,7 @@ class _Linear(torch.autograd.Function):
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
......@@ -1339,10 +1420,12 @@ class _Linear(torch.autograd.Function):
# DGRAD
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
......@@ -1374,12 +1457,12 @@ class _Linear(torch.autograd.Function):
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
wgrad = fp8_gemm(
inputmat_t_total,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
......@@ -1438,6 +1521,7 @@ class _Linear(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -1575,7 +1659,7 @@ class Linear(TransformerEngineBaseModule):
if self.parallel_mode == "column":
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
else:
self.register_buffer("bias", torch.Tensor(), persistent=False)
self.register_buffer("bias", torch.Tensor().type(params_dtype), persistent=False)
with torch.no_grad():
self.bias.zero_()
......@@ -1629,7 +1713,13 @@ class Linear(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = bias if bias is not None else self.bias
out = _Linear.apply(
if self.training:
linear_fn = _Linear.apply
args = []
else:
linear_fn = _Linear.forward
args = [None]
args += (
weight if weight is not None else self.weight,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
......@@ -1645,7 +1735,9 @@ class Linear(TransformerEngineBaseModule):
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.training,
)
out = linear_fn(*args)
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
......@@ -1687,6 +1779,7 @@ class _LayerNormMLP(torch.autograd.Function):
return_layernorm_output: bool,
bias_gelu_nvfusion: bool,
set_parallel_mode: bool,
is_training: bool
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -1706,6 +1799,7 @@ class _LayerNormMLP(torch.autograd.Function):
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output:
if is_training:
ln_out, mu, rsigma = layernorm_fwd_fp8(
inputmat,
ln_weight,
......@@ -1715,6 +1809,16 @@ class _LayerNormMLP(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
ln_out = layernorm_fwd_fp8_inf(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps
......@@ -1726,9 +1830,14 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
)
else:
if is_training:
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
ln_out_return = ln_out
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps
), None, None
ln_out_return = ln_out
# Column Parallel Linear
if set_parallel_mode and sequence_parallel:
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
......@@ -1745,6 +1854,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_bias else fc2_bias
if update_fp8_weights:
if is_training:
fp8_cast_transpose_fused(
fc1_weight,
fp8_meta["scaling_fwd"],
......@@ -1762,13 +1872,30 @@ class _LayerNormMLP(torch.autograd.Function):
cast_out=fc2_weight_fp8,
transpose_out=fc2_weight_t_fp8,
)
else:
fc1_weight_t_fp8 = None
fc1_weight_fp8 = cast_to_fp8(
fc1_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
fc2_weight_t_fp8 = None
fc2_weight_fp8 = cast_to_fp8(
fc2_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
)
fc1_out = fp8_gemm(
fc1_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
ln_out_total,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
activation_dtype,
get_workspace(),
......@@ -1786,10 +1913,12 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_out = fp8_gemm(
fc2_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT],
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
gelu_out,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM2_INPUT],
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
activation_dtype,
get_workspace(),
......@@ -1816,7 +1945,7 @@ class _LayerNormMLP(torch.autograd.Function):
gelu=not bias_gelu_nvfusion,
)
if bias_gelu_nvfusion:
if bias_gelu_nvfusion and is_training:
fc1_out, _, _ = fc1_outputs
gelu_out = bias_gelu_fused(fc1_out, fc1_bias)
else:
......@@ -1830,7 +1959,7 @@ class _LayerNormMLP(torch.autograd.Function):
bias=fc2_bias,
use_bias=use_bias,
)
if is_training:
ctx.save_for_backward(
inputmat,
ln_weight,
......@@ -1873,6 +2002,7 @@ class _LayerNormMLP(torch.autograd.Function):
return fc2_out, ln_out_return.view_as(inp)
return fc2_out
@staticmethod
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
......@@ -1931,10 +2061,12 @@ class _LayerNormMLP(torch.autograd.Function):
# FC2 DGRAD; Unconditional
fc2_dgrad = fp8_gemm(
fc2_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_WEIGHT],
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
......@@ -1947,12 +2079,12 @@ class _LayerNormMLP(torch.autograd.Function):
gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
fc2_wgrad = fp8_gemm(
gelu_out_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_INPUT],
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
......@@ -2010,10 +2142,12 @@ class _LayerNormMLP(torch.autograd.Function):
# FC1 DGRAD: Unconditional
fc1_dgrad = fp8_gemm(
fc1_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
dgelu,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT2],
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
......@@ -2078,12 +2212,12 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
fc1_wgrad = fp8_gemm(
ln_out_total_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
dgelu_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT2
],
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
......@@ -2178,6 +2312,7 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -2372,7 +2507,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
)
)
else:
self.register_buffer("fc2_bias", torch.Tensor(), persistent=False)
self.register_buffer("fc2_bias", torch.Tensor().type(params_dtype), persistent=False)
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
......@@ -2423,7 +2558,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
"""
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
out = _LayerNormMLP.apply(
if self.training:
fwd_fn = _LayerNormMLP.apply
args = []
else:
fwd_fn = _LayerNormMLP.forward
args = [None]
args += (
inp,
self.layer_norm_weight,
self.layer_norm_bias,
......@@ -2448,7 +2589,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.return_layernorm_output,
self.bias_gelu_nvfusion,
self.set_parallel_mode,
self.training,
)
out = fwd_fn(*args)
if self.return_layernorm_output:
out, ln_out = out
......
......@@ -5,9 +5,10 @@
"""Fused scaled masked softmax functions"""
import os
from typing import Callable, Tuple, Union
import torch
from torch import nn
import torch._C._onnx as _C_onnx
from torch.onnx import _type_utils
import transformer_engine_extensions as tex
......@@ -46,6 +47,36 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
return input_grads, None
@staticmethod
def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
"""ScaledUpperTriangMaskedSoftmax symbolic method"""
def triangular_mask():
dtype = _type_utils.JitScalarType.INT64
ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype)
k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
mask = g.op("Trilu", ones, k, upper_i=1)
mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
return mask
# Captures the logic of function scaled_upper_triang_masked_softmax_warp_forward
if inputs.type().scalarType() == "BFloat16":
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
mask = triangular_mask()
one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
inv_mask = g.op("Sub", one, mask)
neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
softmax_mask = g.op("Mul", mask, neg_tenK)
scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
scaled = g.op("Mul", inputs, scale_input)
masked_scaled = g.op("Mul", inv_mask, scaled)
masked = g.op("Add", masked_scaled, softmax_mask)
out = g.op("Softmax", masked)
if inputs.type().scalarType() == "BFloat16":
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return out
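For reference, the graph built above corresponds to the following eager-mode computation (an illustrative sketch only, not part of the exported code; the function name is made up):

import torch

def scaled_upper_triang_masked_softmax_reference(x: torch.Tensor, scale: float) -> torch.Tensor:
    # The Trilu node with k=1 and upper_i=1 is the strictly-upper-triangular causal mask,
    # i.e. torch.triu(..., diagonal=1) in eager mode.
    mask = torch.triu(torch.ones_like(x, dtype=torch.int64), diagonal=1).bool()
    # (1 - mask) * (x * scale) + mask * -10000 collapses to a masked fill with -10000.
    masked = (x * scale).masked_fill(mask, -10000.0)
    # Softmax over the last dimension, matching the Softmax node above.
    return torch.softmax(masked, dim=-1)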
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
......@@ -78,6 +109,35 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
)
return input_grads, None, None
@staticmethod
def symbolic(
g: torch.Graph,
inputs: torch._C.Value,
mask: torch._C.Value,
scale: float) -> torch._C.Value:
"""ScaledMaskedSoftmax symbolic method"""
# Captures the logic of function scaled_masked_softmax_warp_forward.
# output = softmax(mask(input*scale))
# Computed as:
# masked_scaled = (1 - mask)*(input*scale)
# softmax_mask = mask * -10000
# output = softmax(masked_scaled + softmax_mask)
if inputs.type().scalarType() == "BFloat16":
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
scaled = g.op("Mul", inputs, scale_input)
one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
inv_mask = g.op("Sub", one, mask)
# Note: type is hard coded because softmax uses FP16 or BF16
neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
softmax_mask = g.op("Mul", mask, neg_tenK)
masked_scaled = g.op("Mul", inv_mask, scaled)
masked = g.op("Add", masked_scaled, softmax_mask)
out = g.op("Softmax", masked)
if inputs.type().scalarType() == "BFloat16":
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return out
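The decomposition spelled out in the comment above is equivalent to a plain masked fill; a minimal eager-mode sketch (illustrative only, the function name is made up):

import torch

def scaled_masked_softmax_reference(x: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
    # (1 - mask) * (x * scale) + mask * -10000 is a masked fill with -10000 where mask == 1.
    masked = (x * scale).masked_fill(mask.bool(), -10000.0)
    return torch.softmax(masked, dim=-1)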
class ScaledSoftmax(torch.autograd.Function):
"""
......@@ -107,6 +167,19 @@ class ScaledSoftmax(torch.autograd.Function):
)
return input_grads, None, None
@staticmethod
def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
"""ScaledSoftmax symbolic method"""
if inputs.type().scalarType() == "BFloat16":
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
scaled = g.op("Mul", inputs, scale_input)
out = g.op("Softmax", scaled)
if inputs.type().scalarType() == "BFloat16":
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return out
class FusedScaleMaskSoftmax(nn.Module):
"""
......@@ -163,7 +236,7 @@ class FusedScaleMaskSoftmax(nn.Module):
and attn_batches % 4 == 0  # np * b must be a multiple of 4
):
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(sk)
batch_per_block = self.get_batch_per_block(int(sk))
if self.attn_mask_type == "causal":
if attn_batches % batch_per_block == 0:
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
ONNX symbolic functions for Transformer Engine
Warnings of the type pasted below are a known PyTorch issue
(https://github.com/pytorch/pytorch/issues/81693):
tests/test_onnx_export.py::test_export_cast_ops[112]
/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py:649:
UserWarning: The shape inference of trt::TRT_FP8DequantizeLinear type is missing,
so it may result in wrong shape inference for the exported graph.
Please consider adding it in symbolic function. (Triggered internally at
/opt/pytorch/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1880.)
_C._jit_pass_onnx_graph_shape_type_inference(
Scale tensors are treated as lists ("fs") instead of tensors ("v") because we need to access
specific entries using the index passed as `fp8_tensor`. If you fail to do this you will get
the following error when accessing a specific scale element (e.g. `scale_inv[fp8_tensor]`):
TypeError: 'torch._C.Value' object is not subscriptable
"""
import torch
from torch.onnx import symbolic_helper, register_custom_op_symbolic
import torch._C._onnx as _C_onnx
import transformer_engine_extensions as tex
# This file registers custom op symbolic ONNX functions and does not export any symbols.
__all__ = []
# Custom ops spec version
VER = 1
UNSPECIFIED_TYPE = -1
def make_op_name(op_name: str) -> str:
"""custom op name"""
return "trt::" + op_name
def quantize(g, inputs, scale_inv, fp8_tensor):
"""Helper Function for Quantization"""
output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
# Q inputs are currently constrained to FP32 due to a similar limitation in ORT
# custom ops, so cast the input if needed.
if inputs.type().scalarType() == "Half":
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
q_op = g.op(
make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType(
inputs.type().with_dtype(torch.uint8).with_sizes(output_shape))
return q_op
def dequantize(g, inputs, scale_inv, fp8_tensor, otype):
"""Helper Function for Dequantization"""
output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
out = g.op(make_op_name("TRT_FP8DequantizeLinear"), inputs, scale).setType(
inputs.type().with_dtype(torch.float32).with_sizes(output_shape))
# DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
# custom ops, so cast the output if needed.
if otype == int(tex.DType.kFloat16):
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return out
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for cast_to_fp8"""
# pylint: disable=unused-argument
return quantize(g, inputs, scale_inv, fp8_tensor)
@symbolic_helper.parse_args("v", "fs", "i", "i", "i")
def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
"""ONNX graph for cast_from_fp8"""
# pylint: disable=unused-argument
return dequantize(g, inputs, scale_inv, fp8_tensor, otype)
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_gelu"""
# pylint: disable=unused-argument
gelu = torch.onnx.symbolic_opset9.gelu(g, inputs, "tanh")
out = quantize(g, gelu, scale_inv, fp8_tensor)
return out
@symbolic_helper.parse_args("v", "fs", "i", "i", "i",
"v", "fs", "i", "i", "i",
"v", "i", "v", "v", "i",
"v", "i", "i", "i")
def onnx_te_gemm(
g,
weight,
weight_scale_inverse,
weight_fp8_tensor,
weight_type,
trans_weight,
inputs,
input_scale_inverse,
input_fp8_tensor,
input_type,
trans_input,
out,
out_type,
bias,
pre_gelu_out,
grad,
workspace,
workspaceSize,
accumulate,
use_split_accumulator):
"""ONNX graph for te_gemm"""
# pylint: disable=unused-argument
is_fp16 = bias.type().scalarType() == "Half"
if input_type == int(tex.DType.kFloat8E4M3):
inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, UNSPECIFIED_TYPE)
if weight_type == int(tex.DType.kFloat8E4M3):
weight = dequantize(g, weight, weight_scale_inverse, weight_fp8_tensor, UNSPECIFIED_TYPE)
output = g.op("Gemm", inputs, weight, transA_i=trans_input, transB_i=trans_weight)
empty_tensor_size = [0]
bias_empty = torch.onnx.symbolic_helper._get_tensor_sizes(bias) == empty_tensor_size
pre_gelu_out_empty = torch.onnx.symbolic_helper._get_tensor_sizes(pre_gelu_out) \
== empty_tensor_size
if not bias_empty:
if pre_gelu_out_empty:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
output = g.op('Add', output, bias)
else:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
output = g.op('Add', output, bias)
output = torch.onnx.symbolic_opset9.gelu(g, output)
else:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return output
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i")
def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for layernorm_fwd_fp8"""
# pylint: disable=unused-argument
ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
return fp8_ln
@symbolic_helper.parse_args("v", "v", "v", "f")
def onnx_layernorm_fwd(g, inputs, weight, bias, eps):
"""ONNX graph for layernorm_fwd"""
# pylint: disable=unused-argument
normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
if normalized_shape is None:
ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs)
assert ndim is not None
normalized_shape = list(range(0, ndim))
# Dim 0 is the batch/token axis, so normalized_shape keeps every dim except dim 0
normalized_shape = normalized_shape[1:]
ln = torch.onnx.symbolic_opset9.layer_norm(
g,
inputs,
normalized_shape,
weight,
bias,
eps,
False # cudnn_enable (not relevant)
)
return ln
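As a shape example (sizes assumed, not taken from the tests): TE modules flatten activations to 2D before the layernorm, so the slice above reduces the full input shape to the trailing dims that LayerNorm normalizes over:

import torch

x = torch.randn(1024, 768)               # [tokens, hidden], as produced when the input is viewed as (-1, in_features)
normalized_shape = list(x.shape)[1:]     # -> [768]
y = torch.nn.functional.layer_norm(x, normalized_shape, eps=1e-5)
assert y.shape == x.shape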
register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
register_custom_op_symbolic('tex_ts::fp8_gelu_ts', onnx_fp8_gelu, VER)
register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER)
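With these symbolics registered, a plain torch.onnx.export of a TE module in eval mode goes through the inference paths added earlier in this PR. A minimal sketch (shapes, file name and opset version are illustrative; the exact recipe used for validation lives in tests/test_onnx_export.py):

import torch
import transformer_engine.pytorch as te

model = te.Linear(768, 768, bias=True).eval()   # eval() selects the inference code path
inp = torch.randn(64, 768, device="cuda")
torch.onnx.export(
    model,
    (inp,),
    "te_linear.onnx",
    opset_version=15,
    input_names=["input"],
    output_names=["output"],
)

Running the same export under te.fp8_autocast additionally emits the trt::TRT_FP8QuantizeLinear / trt::TRT_FP8DequantizeLinear nodes defined above, which require a matching ONNX Runtime custom-op library to execute.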