Unverified Commit 6c9ce179 authored by asfiyab-nvidia's avatar asfiyab-nvidia Committed by GitHub
Browse files

Add ONNX export support for TE modules (#41)



* Add ONNX export support for TE modules (#1)

* Add TorchScript Operators
* Add symbolic methods to ONNX exporter
* Add tests for the ONNX export
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* fixes for pylint tests
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* fix pylint warning in softmax.py
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* move FP8 ORT lib inside tests/
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* enable cross attention tests
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* refactor code by @nzmora
* Increase layernorm FP16 threshold
* Normalize onnx file names: _ separates configs; - separates words in a single config
* Add get_attn_mask_str and fix mask string
* Add missing ONNX files
* Moved generated ONNX files to tests/gen_onnx_models/
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* fix merge conflict changes
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* fix Q/DQ scale input
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* enable FP16 config when bias is disabled
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* fix pylint check errors
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* updates
1. remove List import for pylint failure
2. address comments: remove state tensors from GPU
3. address comments: Update reverse_map_dtype function and add to namespace
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* minor fix: coding guidelines
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* changes:
1. skip FP8 tests on non-hopper devices
2. minor fix for C++ lint check
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* fix onnxruntime version
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* minor fix: add space between code and comment
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* changes
1. update copyrights
2. update path to ORT .so
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* Apply suggestions from code review
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarasfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com>
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>
Signed-off-by: default avatarasfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e2ad34e9
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
*.nsys-rep *.nsys-rep
*.ncu-rep *.ncu-rep
*.sqlite *.sqlite
*.onnx
.eggs .eggs
build/ build/
*.so *.so
......
...@@ -6,5 +6,6 @@ set -e ...@@ -6,5 +6,6 @@ set -e
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
pip install pytest==6.2.5 pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/test_transformerengine.py pytest -v -s $TE_PATH/tests/test_transformerengine.py
pytest -v -s $TE_PATH/tests/test_onnx_export.py
...@@ -14,7 +14,7 @@ from setuptools import setup, find_packages, Extension ...@@ -14,7 +14,7 @@ from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
from distutils.version import LooseVersion from distutils.version import LooseVersion
from distutils.file_util import copy_file from distutils.file_util import copy_file
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
path = os.path.dirname(os.path.realpath(__file__)) path = os.path.dirname(os.path.realpath(__file__))
...@@ -85,6 +85,7 @@ include_dirs = make_abs_path(include_dirs) ...@@ -85,6 +85,7 @@ include_dirs = make_abs_path(include_dirs)
pytorch_sources = [ pytorch_sources = [
"transformer_engine/pytorch/csrc/extensions.cu", "transformer_engine/pytorch/csrc/extensions.cu",
"transformer_engine/pytorch/csrc/common.cu", "transformer_engine/pytorch/csrc/common.cu",
"transformer_engine/pytorch/csrc/ts_fp8_op.cpp",
] ]
pytorch_sources = make_abs_path(pytorch_sources) pytorch_sources = make_abs_path(pytorch_sources)
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
This file contains tests for exporting TransformerEngine models to ONNX.
"""
import os
import pytest
import warnings
import numpy as np
import math
import onnxruntime as ort
import torch
from torch import nn as nn
from typing import Union, Tuple
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, fp8_gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.module import get_workspace
import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine.pytorch.softmax as softmax_defs
from transformer_engine.pytorch.utils import get_default_init_method
# Directory where generated ONNX test models are stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
ONNX_FILES_DIR = os.path.join(TESTS_DIR, "./gen_onnx_models")
# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
# ScaledUpperTriangMaskedSoftmax is exported via ONNX::Trilu which was introduced in opset 14.
TRILU_OPSET = 14
# Opset used in the ONNX files generated by the tests.
OPSET = 15
# Sanity check: the export opset must be new enough for every operator the tests emit.
assert OPSET >= TRILU_OPSET
# Marker that skips FP8 tests on devices with compute capability below 9.0 (pre-Hopper).
skip_FP8 = pytest.mark.skipif(
    torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
    reason="Device compute capability 9.x required for FP8 execution.",
)
def create_fp8_recipe():
    """Build the delayed-scaling FP8 recipe shared by all export tests."""
    return recipe.DelayedScaling(
        margin=0, interval=1, fp8_format=recipe.Format.E4M3)
def do_export(
    model: torch.nn.Module,
    inp: torch.Tensor,
    fname: str,
    use_fp8: bool=True,
    opset: int=OPSET,
    input_names: list=None,
    output_names: list=None,
):
    """Export `model` to an ONNX file under ONNX_FILES_DIR.

    Tracing runs under inference mode and TE's fp8_autocast (enabled per
    `use_fp8`); TracerWarnings are suppressed to keep test output readable.

    Args:
        model: module to export; moved to CUDA and set to eval mode here.
        inp: a single tensor, or a list/tuple of tensors, used as example input.
        fname: ONNX file name (joined with ONNX_FILES_DIR).
        use_fp8: whether to trace with FP8 autocast enabled.
        opset: ONNX opset version to target.
        input_names: ONNX graph input names; defaults to ["input"].
        output_names: ONNX graph output names; defaults to ["output"].
    """
    # Fix: the original signature used mutable list defaults, which are shared
    # across calls and silently mutated if a caller appends to them.
    if input_names is None:
        input_names = ["input"]
    if output_names is None:
        output_names = ["output"]
    fp8_recipe = create_fp8_recipe()
    with torch.inference_mode(), te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
        warnings.filterwarnings(
            action='ignore',
            category=torch.jit.TracerWarning,
            module=r'.*'
        )
        model.cuda().eval()
        os.makedirs(ONNX_FILES_DIR, exist_ok=True)
        fname = os.path.join(ONNX_FILES_DIR, fname)
        # torch.onnx.export expects a tuple of example inputs.
        torch.onnx.export(model,
                          inp if isinstance(inp, (list, tuple)) else (inp,),
                          fname,
                          verbose=False,
                          opset_version=opset,
                          input_names=input_names,
                          output_names=output_names,
                          do_constant_folding=True,
                          operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
def to_numpy(tensor):
    """Copy a (possibly GPU-resident) tensor to host memory as a numpy array."""
    host_tensor = tensor.cpu()
    return host_tensor.numpy()
def set_layer_scale(module: torch.nn.Module, scale: float):
    """Initialize `module`'s FP8 state and pin its forward scaling factors.

    Sets scale to 1/scale and scale_inv to scale for both forward tensors,
    so quantization behavior is deterministic in the exported graph.
    """
    module.fp8_init()
    ones = torch.ones(2, dtype=torch.float32, device="cuda")
    scaling_fwd = module.fp8_meta["scaling_fwd"]
    scaling_fwd.scale = ones / scale
    scaling_fwd.scale_inv = ones * scale
def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool):
    """Transformer Engine forward propagation.

    Runs `model` under inference mode with FP8 autocast enabled per `is_fp8`.
    Return results after copying to the CPU and converting to numpy.
    """
    fp8_recipe = create_fp8_recipe()
    with torch.inference_mode(), te.fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
        # Accept either a tuple of inputs or a single tensor.
        te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
        # Normalize to a tuple so single-output models are handled uniformly.
        if not isinstance(te_outputs, tuple):
            te_outputs = (te_outputs,)
        te_outputs_np = [to_numpy(te_output) for te_output in te_outputs]
        return te_outputs_np
def validate_result(
    fname: str,
    inps: Union[Tuple[torch.Tensor], torch.Tensor],
    model: torch.nn.Module,
    atol: float=1.e-8, # np.isclose default atol
    rtol: float=1.e-5, # np.isclose default rtol
    max_errors_printed: int=10,
    is_fp8: bool=False,
):
    """Validate the outputs of an ONNX model vs. ONNX Runtime.

    Runs the exported ONNX file `fname` through ONNX Runtime and the TE
    `model` directly (via te_infer), then compares outputs element-wise
    with np.isclose. On mismatch, prints up to `max_errors_printed`
    diverging values and raises ValueError.
    """
    def create_ort_session(fname: str, is_fp8: bool):
        """Create an ONNX Runtime session for validation."""
        def load_custom_ops(session_opts: ort.SessionOptions):
            """For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension."""
            if not os.path.exists(ORT_CUSTOM_OPS_LIB):
                raise FileNotFoundError(f"Unable to find {ORT_CUSTOM_OPS_LIB}")
            session_opts.register_custom_ops_library(ORT_CUSTOM_OPS_LIB)
            print("registered custom FP8 Q/DQ ops!")
        if is_fp8:
            sess_options = ort.SessionOptions()
            load_custom_ops(sess_options)
            # Model loading successfully indicates that the custom op node could be resolved successfully
            s = ort.InferenceSession(fname, sess_options=sess_options)
        else:
            s = ort.InferenceSession(fname)
        return s
    def create_ort_input_dict(session, inps):
        # Map positional input tensors to ORT graph input names. None entries
        # (e.g. an absent attention mask) have no corresponding graph input, so
        # later tensors shift down to earlier input slots.
        inp_dict = {}
        if isinstance(inps, tuple) or isinstance(inps, list):
            nonetype_inputs = 0
            for idx, inp in enumerate(inps):
                if inp is None:
                    nonetype_inputs += 1
                    continue
                inp_dict[session.get_inputs()[idx - nonetype_inputs].name] = to_numpy(inp)
        else:
            inp_dict[session.get_inputs()[0].name] = to_numpy(inps)
        return inp_dict
    # Run ORT session and TE model.
    fname = os.path.join(ONNX_FILES_DIR, fname)
    ort_s = create_ort_session(fname, is_fp8)
    onnx_outputs = ort_s.run(None, input_feed=create_ort_input_dict(ort_s, inps))
    te_outputs = te_infer(model, inps, is_fp8)
    # Compare ORT and TE outputs.
    assert len(onnx_outputs) == len(te_outputs)
    for onnx_output, te_output in zip(onnx_outputs, te_outputs):
        # Compare ORT and PyTorch outputs.
        # np.isclose: abs(a - b) <= (atol + rtol * abs(b))
        ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)
        mismatches = ac.nonzero()
        mismatched_ids = [loc for loc in zip(*mismatches)]
        if mismatched_ids:
            # Log some information in case of error.
            print("*" * 100)
            print(onnx_output.shape)
            nb_vals = min(len(mismatched_ids), max_errors_printed)
            print(f"Detected {len(mismatched_ids)} diverging values.\nShowing first {nb_vals} errors (ONNX -- TE):")
            abs_err = abs(onnx_output - te_output)
            for loc in mismatched_ids[:nb_vals]:
                ref = te_output[loc]
                print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}")
            raise ValueError(f"Output validation of {fname} failed with {len(mismatched_ids)} errors")
def create_meta(scale_factor: float, size: int=1):
    """Build an FP8TensorMeta holding `size` scales.

    scale is set to scale_factor, scale_inv to its reciprocal, and the
    amax history is zero-initialized.
    """
    fp8_meta = tex.FP8TensorMeta()
    fp8_meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
    fp8_meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
    fp8_meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
    return fp8_meta
def dtype2str(dtype: torch.dtype):
    """Map a torch dtype to the file-name suffix used for generated ONNX models."""
    if dtype == torch.float32:
        return "_fp32"
    if dtype == torch.float16:
        return "_fp16"
    if dtype == torch.bfloat16:
        return "_bf16"
    # Unsupported dtypes raise KeyError, matching dict-lookup behavior.
    raise KeyError(dtype)
def as_te_type(dtype: torch.dtype):
    """Translate a torch dtype into the corresponding TE DType enum value."""
    mapping = {
        torch.float32: tex.DType.kFloat32,
        torch.float16: tex.DType.kFloat16,
        torch.bfloat16: tex.DType.kBFloat16,
    }
    return mapping[dtype]
def get_attn_mask_str(use_mask, attn_mask_type):
    """Return the mask-configuration suffix used in generated ONNX file names.

    Naming mirrors the dispatch in FusedScaleMaskSoftmax::forward_fused_softmax.
    """
    if attn_mask_type is None:
        return "_mask" if use_mask else "_no-mask"
    if attn_mask_type == "causal":
        return "_causal-mask"
    if use_mask and attn_mask_type == "padding":
        return "_padding-mask"
    return "_padding-no-mask"
@skip_FP8
@pytest.mark.parametrize("scale_factor, atol", [
    (1, 1e-7),
    (224, 1e-7)
])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
def test_export_cast_ops(scale_factor: float, atol: float, precision: torch.dtype):
    """Export and validate a model that casts to FP8 and immediately casts back."""
    class TestFP8_QDQ(nn.Module):
        def __init__(self):
            super().__init__()
            self.fp8_tensor = 0
            self.meta = create_meta(scale_factor)
            self.highprec_type = as_te_type(precision)
            self.fp8_type = tex.DType.kFloat8E4M3
        def forward(self, inp):
            # Quantize to FP8...
            ret = cast_to_fp8(
                inp,
                self.meta,
                self.fp8_tensor,
                self.fp8_type)
            # ...then dequantize back to high precision for comparison.
            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                self.highprec_type)
            return ret
    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256
    inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
    high_prec_str = dtype2str(precision)
    fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx"
    model = TestFP8_QDQ()
    do_export(model, inp, fname)
    validate_result(fname, inp, model, atol=atol, is_fp8=True)
@skip_FP8
@pytest.mark.parametrize("scale_factor", [448])
@pytest.mark.parametrize(
    "precision, atol", [
        [torch.float32, 1e-7],
        [torch.float16, 2e-3]
    ])
def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
    """Export and validate the fused GeLU-with-FP8-output op."""
    class TestFP8_Gelu(nn.Module):
        def __init__(self):
            super().__init__()
            self.fp8_tensor = 0
            self.meta = create_meta(scale_factor)
            self.highprec_type = as_te_type(precision)
            self.fp8_type = tex.DType.kFloat8E4M3
        def forward(self, inp):
            ret = fp8_gelu(
                inp,
                self.meta,
                self.fp8_tensor,
                self.fp8_type)
            # Dequantize so outputs can be compared in high precision.
            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                self.highprec_type)
            return ret
    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256
    inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
    high_prec_str = dtype2str(precision)
    fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx"
    model = TestFP8_Gelu()
    do_export(model, inp, fname)
    validate_result(fname, inp, model, rtol=1e-1, atol=atol, is_fp8=True)
@pytest.mark.parametrize("scale_factors",
                         [(224, 224,),
                         ])
@pytest.mark.parametrize(
    "precision, use_fp8, use_bias, use_gelu", [
        (torch.float32, False, False, False),
        (torch.float16, False, False, False),
        (torch.float32, False, True, False),
        (torch.float16, False, True, False),
        (torch.float32, False, True, True),
        (torch.float16, False, True, True),
        # For FP8 GEMM GeLU is not used.
        (torch.float32, True, False, False),
        (torch.float16, True, False, False),
        # When enabling bias we must use float16 or bfloat16 (because of kernel limitations)
        (torch.float16, True, True, False),
        (torch.bfloat16, True, True, False),
    ])
def test_export_gemm(
    precision, # Precision of inputs, weights, output and bias
    use_fp8,
    use_bias,
    use_gelu,
    scale_factors
):
    """Export and validate a raw GEMM, in FP8 or high precision, optionally with bias/GeLU."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
        pytest.skip("Device compute capability 9.x required for FP8 execution.")
    class TestFP8_GEMM(nn.Module):
        # Casts both operands to FP8 and runs an FP8 GEMM.
        def __init__(self, precision, use_bias, gelu, scale_factors):
            super().__init__()
            self.use_bias = use_bias
            self.gelu = gelu
            self.precision = precision
            self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
            self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
            # One scale for the activation, a per-output-channel scale for the weights.
            nb_inp_scales, nb_weight_scales = 1, out_features
            act_scale_factor, weight_scale_factor = scale_factors
            self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
            self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)
            bias_size = nb_weight_scales
            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
            self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")
            self.inp_type = tex.DType.kFloat8E4M3
            self.weights_type = tex.DType.kFloat8E4M3
            self.outp_type = precision
        def forward(self, inp, weight):
            inp_fp8 = cast_to_fp8(
                inp,
                self.meta_inp,
                self.fp8_tensor_inp,
                self.inp_type)
            weight_fp8 = cast_to_fp8(
                weight,
                self.meta_weight,
                self.fp8_tensor_weight,
                self.weights_type)
            # NOTE(review): inp_type is passed with the weight operand and
            # weights_type with the input operand; both are E4M3 here so the
            # result is unaffected — confirm whether the args are swapped.
            ret = fp8_gemm(
                weight_fp8,
                self.meta_weight.scale_inv,
                self.fp8_tensor_weight,
                self.inp_type,
                inp_fp8,
                self.meta_inp.scale_inv,
                self.fp8_tensor_inp,
                self.weights_type,
                self.outp_type,
                get_workspace(),
                bias=self.bias,
                use_bias=self.use_bias,
                fp32_output=(self.precision==torch.float32),
                use_split_accumulator=False)
            return ret
    class Test_GEMM(nn.Module):
        # High-precision (non-FP8) GEMM path.
        def __init__(self, precision, use_bias=False, gelu=False):
            super().__init__()
            self.use_bias = use_bias
            self.gelu = gelu
            self.precision = precision
            bias_size = out_features
            self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
            self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")
        def forward(self, inp, weight):
            outp_type = self.precision
            # note: due to logic in lines 104:116 and L129 in cpp_extensions.py
            # it appears either bias OR gelu can be activated, not both
            ret, _, _ = gemm(
                weight,
                inp,
                outp_type,
                get_workspace(),
                # test bias
                bias=self.bias,
                use_bias=self.use_bias,
                # test gelu
                gelu=self.gelu,
                gelu_input=self.gelu_input,
                grad=False # only True for backward pass
            )
            return ret
    # If gelu is applied then bias must be added, as defined by TE kernel.
    if use_gelu: assert use_bias
    # Set dimensions (these are arbitrary).
    out_features = 128
    hidden_size = 256
    in_features = 64
    inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda")
    weight = torch.randn(out_features, in_features, dtype=precision, device="cuda")
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    gelu_str = "_gelu" if use_gelu else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx"
    if use_fp8:
        model = TestFP8_GEMM(precision, use_bias, use_gelu, scale_factors)
        do_export(model, (inp, weight), fname, use_fp8)
        # FP8 results in FP16/BF16 are exported but not numerically validated here.
        if precision not in (torch.bfloat16, torch.float16):
            validate_result(fname, (inp, weight), model, rtol=1e-2, atol=1e-2, is_fp8=True)
    else:
        model = Test_GEMM(precision, use_bias, use_gelu)
        do_export(model, (inp, weight), fname, use_fp8)
        validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2)
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
def test_export_layernorm(
    use_fp8: bool,
    scale_factor: float,
    precision: torch.dtype
):
    """Export and validate LayerNorm forward, optionally with an FP8 output cast."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
        pytest.skip("Device compute capability 9.x required for FP8 execution.")
    # Set dimensions (these are arbitrary).
    inp_shape = [64, 32]
    class Test_Layernorm(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # `inp` here is the module-level input tensor (closure), not the forward arg.
            normalized_shape = torch.Size(inp.shape[1:])
            self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
            self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
            self.eps = 1e-6 # An arbitrary small value
        def forward(self, inp):
            ret = texcpp.layernorm_fwd_inf(
                inp,
                self.weight,
                self.bias,
                self.eps)
            return ret
    class TestFP8_Layernorm(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            normalized_shape = torch.Size(inp.shape[1:])
            self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
            self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
            self.eps = 1e-6 # An arbitrary small value
            self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
            self.meta = create_meta(scale_factor)
            self.fp8_type = tex.DType.kFloat8E4M3
        def forward(self, inp):
            ret = texcpp.layernorm_fwd_fp8_inf(
                inp,
                self.weight,
                self.bias,
                self.eps,
                self.meta,
                self.fp8_tensor,
                self.fp8_type)
            # Dequantize to high precision so outputs can be compared against ORT.
            ret = cast_from_fp8(
                ret,
                self.meta,
                self.fp8_tensor,
                self.fp8_type,
                tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16)
            return ret
    inp = torch.randn(*inp_shape, device="cuda", dtype=precision)
    model = TestFP8_Layernorm() if use_fp8 else Test_Layernorm()
    high_prec_str = dtype2str(precision)
    fp8_str = f"_fp8-{scale_factor}" if use_fp8 else ""
    fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx"
    do_export(model, inp, fname, use_fp8=use_fp8)
    if precision not in (torch.bfloat16, ):
        # TODO: FP32 has a small threshold (1e-5)
        validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8)
@skip_FP8
@pytest.mark.parametrize("softmax_def", [
    softmax_defs.ScaledUpperTriangMaskedSoftmax,
    softmax_defs.ScaledMaskedSoftmax,
    softmax_defs.ScaledSoftmax,
])
# Softmax kernel only supports FP16 or BF16!
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
def test_export_softmax(softmax_def, precision):
    """Export and validate each fused scaled-softmax autograd function."""
    class Test_Softmax(nn.Module):
        def __init__(self, softmax_function, mask_inp=False):
            super().__init__()
            self.softmax_fn = softmax_function
            self.mask_inp = mask_inp
        def forward(self, inp, mask):
            scale_factor = 8 # arbitrary value
            # ScaledMaskedSoftmax takes a mask argument; the other kernels do not.
            if self.mask_inp:
                ret = self.softmax_fn.apply(inp, mask, scale_factor)
            else:
                ret = self.softmax_fn.apply(inp, scale_factor)
            return ret
    # Set dimensions (these are arbitrary).
    in_features = 64
    hidden_size = 256
    mask = None
    input_names = ["input"]
    inp_shape = [hidden_size, in_features, in_features, in_features]
    if softmax_def == softmax_defs.ScaledUpperTriangMaskedSoftmax:
        # This kernel is fed a 3D input.
        inp_shape = [hidden_size, in_features, in_features]
        kernel_str = "ScaledUpperTriangMaskedSoftmax"
        model = Test_Softmax(softmax_def)
    elif softmax_def == softmax_defs.ScaledMaskedSoftmax:
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(hidden_size, 1, in_features, in_features, device="cuda", dtype=precision)
        mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
        input_names.append("mask")
        kernel_str = "ScaledMaskedSoftmax"
        model = Test_Softmax(softmax_def, mask_inp=True)
    elif softmax_def == softmax_defs.ScaledSoftmax:
        kernel_str = "ScaledSoftmax"
        model = Test_Softmax(softmax_def)
    input_tensor = torch.randn(*inp_shape, device="cuda")
    input_tensor = input_tensor.to(torch.bfloat16) if precision == torch.bfloat16 else input_tensor.half()
    high_prec_str = dtype2str(precision)
    fname = f"{kernel_str}{high_prec_str}.onnx"
    inp = (input_tensor, mask)
    do_export(model, inp, fname, input_names=input_names)
    # BF16 outputs are exported but not validated against ORT — presumably an
    # ORT BF16 limitation; confirm.
    if precision != torch.bfloat16:
        validate_result(fname, inp, model, atol=1e-3)
@pytest.mark.parametrize("scale_factor", [1])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize(
    "precision, use_bias",[
        (torch.float32, False),
        (torch.float32, True),
        (torch.float16, False),
        (torch.float16, True),
        # Todo: cannot configure BF16 when bias is disabled (ORT issue?)
        (torch.bfloat16, False),
        # Todo: cannot configure BF16 when bias is enabled (ORT issue?)
        # (torch.bfloat16, True),
    ])
def test_export_linear(
    scale_factor: float,
    use_fp8: bool,
    use_bias: bool,
    return_bias: bool,
    precision: torch.dtype
):
    """Export and validate te.Linear wrapped in a trivial module."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
        pytest.skip("Device compute capability 9.x required for FP8 execution.")
    # Set dimensions (these are arbitrary).
    in_features = 64
    out_features = 256
    hidden_size = 256
    class Test_Linear(nn.Module):
        def __init__(self,
                     in_features,
                     out_features,
                     use_bias,
                     return_bias,
                     precision
                     ):
            super().__init__()
            self.linear = te.Linear(
                in_features,
                out_features,
                bias=use_bias,
                return_bias=return_bias,
                params_dtype=precision
            )
        def forward(self, inp):
            ret = self.linear(inp)
            return ret
    inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
    # Build and export under fp8_autocast so FP8 state is initialized consistently.
    with te.fp8_autocast(enabled=use_fp8):
        model = Test_Linear(
            in_features,
            out_features,
            use_bias,
            return_bias,
            precision
        ).to(device='cuda')
        if use_fp8:
            set_layer_scale(model.linear, scale_factor)
        do_export(model, inp, fname, use_fp8)
    # BF16 outputs are exported but not validated against ORT.
    if precision in (torch.bfloat16, ):
        return
    if not use_fp8:
        validate_result(fname, inp, model, atol=1e-3)
    else:
        validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
    "precision, use_bias",[
        (torch.float32, False),
        (torch.float32, True),
        (torch.float16, True),
        (torch.float16, False),
    ])
def test_export_layernorm_linear(
    scale_factor: float,
    use_fp8: bool,
    use_bias: bool,
    return_bias: bool,
    return_layernorm_output: bool,
    precision: torch.dtype
):
    """Export and validate te.LayerNormLinear."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
        pytest.skip("Device compute capability 9.x required for FP8 execution.")
    # Set dimensions (these are arbitrary).
    in_features = 64
    out_features = 256
    hidden_size = 256
    inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx"
    # Build and export under fp8_autocast so FP8 state is initialized consistently.
    with te.fp8_autocast(enabled=use_fp8):
        model = te.LayerNormLinear(
            hidden_size,
            3 * hidden_size,
            bias=use_bias,
            return_bias=return_bias,
            return_layernorm_output=return_layernorm_output,
            params_dtype=precision,
        ).to(device='cuda')
        if use_fp8:
            set_layer_scale(model, scale_factor)
        do_export(model, inp, fname, use_fp8)
    if not use_fp8:
        validate_result(fname, inp, model, atol=1e-3)
    elif precision not in (torch.bfloat16,):
        # FP8 comparison uses a looser absolute tolerance.
        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
    "precision, use_bias",[
        (torch.float32, False),
        (torch.float32, True),
        (torch.float16, True),
        (torch.float16, False),
    ])
def test_export_layernorm_mlp(
    scale_factor: float,
    use_fp8: bool,
    use_bias: bool,
    return_bias: bool,
    return_layernorm_output: bool,
    precision: torch.dtype
):
    """Export and validate te.LayerNormMLP."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
        pytest.skip("Device compute capability 9.x required for FP8 execution.")
    # Set dimensions (these are arbitrary).
    in_features = 64
    out_features = 256
    hidden_size = 256
    ffn_hidden_size = 256
    inp = torch.randn(in_features, out_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    high_prec_str = dtype2str(precision)
    fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}.onnx"
    # Build and export under fp8_autocast so FP8 state is initialized consistently.
    with te.fp8_autocast(enabled=use_fp8):
        model = te.LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            bias=use_bias,
            return_bias=return_bias,
            return_layernorm_output=return_layernorm_output,
            params_dtype=precision,
        ).to(device='cuda')
        if use_fp8:
            set_layer_scale(model, scale_factor)
        do_export(model, inp, fname, use_fp8)
    if not use_fp8:
        validate_result(fname, inp, model, atol=1e-3)
    else:
        # FP8 comparison uses a looser absolute tolerance.
        validate_result(fname, inp, model, atol=2e-2, is_fp8=use_fp8)
@skip_FP8
@pytest.mark.parametrize(
    "precision, use_mask, attn_mask_type", [
        (torch.float32, False, None), # calls forward_torch_softmax
        (torch.float32, True, None), # calls forward_torch_softmax
        (torch.float16, False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
        (torch.float16, True, "padding"), # calls ScaledMaskedSoftmax
        (torch.float16, False, "padding"), # calls ScaledSoftmax
    ])
@pytest.mark.parametrize("attention_softmax_in_fp32",
                         [True, False])
@pytest.mark.parametrize("apply_query_key_layer_scaling",
                         [True, False])
def test_export_core_attention(
    precision: torch.dtype,
    use_mask: bool,
    attn_mask_type: str,
    attention_softmax_in_fp32: bool,
    apply_query_key_layer_scaling: bool,
):
    """Export and validate te.transformer.CoreAttention across softmax dispatch paths."""
    # Set dimensions (these are arbitrary).
    kv_channels = 64
    num_attention_heads = 1
    qkv_size = (2048, 4, num_attention_heads, kv_channels)
    query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
    input_names = ["query", "key", "value"]
    attention_mask = None
    if use_mask:
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
        input_names.append("attention_mask")
    inp = (query_layer, key_layer, value_layer, attention_mask)
    sm_prec_str = "_sm-fp32" if attention_softmax_in_fp32 else "_sm-fp16"
    qk_scaling_str = "_qk-scaling" if apply_query_key_layer_scaling else ""
    mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    high_prec_str = dtype2str(precision)
    fname = f"te.core_attention{mask_str}{qk_scaling_str}{sm_prec_str}{high_prec_str}.onnx"
    # CoreAttention requires a concrete mask type; fall back to 'causal' when
    # the parametrization left it unspecified (file name already computed above).
    if attn_mask_type is None:
        attn_mask_type = 'causal'
    model = te.transformer.CoreAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=kv_channels,
        # Dropout is inert at inference time since do_export puts the model in eval mode.
        attention_dropout=0.5,
        attn_mask_type=attn_mask_type,
        attention_softmax_in_fp32=attention_softmax_in_fp32,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
    ).to(device='cuda')
    do_export(model,
              inp,
              fname,
              input_names=input_names,
              use_fp8=True)
    validate_result(fname, inp, model, atol=1e-2)
# Shared (use_mask, attn_mask_type) configurations for the attention tests below.
test_configs_multihead_attention = [
    #"use_mask, attn_mask_type"
    (False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
    (True, "padding"), # calls ScaledMaskedSoftmax
    (False, "padding"), # calls ScaledSoftmax
]
# Shared (input_layernorm, attention_type, fuse_qkv_params) configurations.
test_configs_attention_type = [
    #"input_layernorm, attention_type, fuse_qkv_params"
    (True, "self", True),
    (False, "self", True),
    (True, "self", False),
    (False, "self", False),
    # disabled because query_bias (reqd for cross attention) is defined when fuse_qkv_params is False
    # (True, "cross", True),
    # (False, "cross", True),
    (True, "cross", False),
    # disabled because TypeError: cannot assign 'transformer_engine.pytorch.module.Linear'
    # as parameter 'query' (torch.nn.Parameter or None expected)
    # (False, "cross", False),
]
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type)
def test_export_multihead_attention(
    use_fp8: bool,
    use_mask: bool,
    attn_mask_type: str,
    precision: torch.dtype,
    return_layernorm_output: bool,
    input_layernorm: bool,
    attention_type: str,
    fuse_qkv_params: bool
):
    """Export and validate te.transformer.MultiHeadAttention (self and cross attention)."""
    # Skip FP8 tests on non-hopper devices
    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
        pytest.skip("Device compute capability 9.x required for FP8 execution.")
    # Layer configuration (these values are arbitrary).
    hidden_size = 256
    sequence_length = 128
    batch_size = 4
    num_attention_heads = 32
    kv_channels = 8
    attention_dropout = 0.1
    layernorm_epsilon = 1e-5
    init_method = output_layer_init_method = get_default_init_method()
    attention_args = (
        hidden_size,
        num_attention_heads,
        kv_channels,
        attention_dropout,
        layernorm_epsilon,
        init_method,
        output_layer_init_method,
    )
    hidden_states = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
    attention_mask = None
    # An explicit mask tensor is only supplied for non-causal mask types.
    if use_mask and attn_mask_type != "causal":
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
    encoder_output = None
    # Cross-attention additionally consumes the encoder output.
    if attention_type == "cross":
        encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
    inp = (hidden_states, attention_mask, encoder_output)
    input_names = ["hidden_states", "attention_mask", "encoder_output"]
    fp8_str = "_fp8" if use_fp8 else ""
    dtype_str = dtype2str(precision)
    attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention"
    fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else ""
    attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    input_ln_str = "_input-ln" if input_layernorm else ""
    fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"
    model = te.transformer.MultiHeadAttention(
        *attention_args,
        attn_mask_type=attn_mask_type,
        params_dtype=precision,
        return_layernorm_output=return_layernorm_output,
        input_layernorm=input_layernorm,
        attention_type=attention_type,
        fuse_qkv_params=fuse_qkv_params,
    ).to(device='cuda')
    do_export(model, inp, fname, use_fp8, input_names=input_names)
    if not use_fp8:
        validate_result(fname, inp, model, atol=1e-3)
    elif precision != torch.float16:
        # NOTE(review): FP8 results in FP16 are exported but not validated — confirm intended.
        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8)
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [
    #True, # TO DO: handle this
    False
])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("apply_query_key_layer_scaling", [True, False])
def test_export_transformer_layer(
    use_fp8: bool,
    use_mask: bool,
    attn_mask_type: str,
    output_layernorm: bool,
    precision: torch.dtype,
    fuse_qkv_params: bool,
    apply_query_key_layer_scaling: bool
):
    """Export a te.TransformerLayer to ONNX and validate it against PyTorch.

    Validation is skipped for the FP8 + FP16 combination (see the atol
    thresholds below).
    """
    # FP8 kernels require compute capability 9.x (Hopper) or newer.
    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
        pytest.skip("Device compute capability 9.x required for FP8 execution.")

    # Layer configuration.
    hidden_size = 64
    sequence_length = 128
    batch_size = 1
    ffn_hidden_size = 256
    num_attention_heads = 4

    hidden_states = torch.rand(
        sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
    input_names = ["input"]

    attention_mask = None
    if use_mask and attn_mask_type != "causal":
        # Generate a random mask with 50% probability for 0 or 1.
        mask_probs = 0.5 * torch.ones(
            batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(mask_probs).to("cuda", dtype=torch.bool)
        input_names.append("attention_mask")
    inp = (hidden_states, attention_mask)

    # Build the ONNX file name from the active configuration:
    # "_" separates configs; "-" separates words within one config.
    fname = "".join([
        "te.transformer_layer",
        "_fp8" if use_fp8 else "",
        get_attn_mask_str(use_mask, attn_mask_type),
        "_fused-qkv" if fuse_qkv_params else "",
        "_qk-scaling" if apply_query_key_layer_scaling else "",
        dtype2str(precision),
        ".onnx",
    ])

    model = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        self_attn_mask_type=attn_mask_type,
        output_layernorm=output_layernorm,
        params_dtype=precision,
        fuse_qkv_params=fuse_qkv_params,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling).to(device='cuda')
    do_export(model, inp, fname, use_fp8)

    if not use_fp8:
        validate_result(fname, inp, model, atol=1e-3)
    elif precision != torch.float16:
        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8)
...@@ -10,3 +10,12 @@ from .module import LayerNorm ...@@ -10,3 +10,12 @@ from .module import LayerNorm
from .transformer import TransformerLayer from .transformer import TransformerLayer
from .fp8 import fp8_autocast from .fp8 import fp8_autocast
from .distributed import checkpoint from .distributed import checkpoint
# Register custom op symbolic ONNX functions
from .te_onnx_extensions import (
onnx_cast_to_fp8,
onnx_cast_from_fp8,
onnx_fp8_gelu,
onnx_te_gemm,
onnx_layernorm_fwd_fp8,
onnx_layernorm_fwd,
)
...@@ -12,9 +12,11 @@ from .constants import TE_DType ...@@ -12,9 +12,11 @@ from .constants import TE_DType
def fp8_gemm( def fp8_gemm(
A: torch.Tensor, A: torch.Tensor,
A_scale_inv: torch.Tensor, A_scale_inv: torch.Tensor,
A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
A_dtype: tex.DType, A_dtype: tex.DType,
B: torch.Tensor, B: torch.Tensor,
B_scale_inv: torch.Tensor, B_scale_inv: torch.Tensor,
B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
B_dtype: tex.DType, B_dtype: tex.DType,
out_dtype: torch.dtype, out_dtype: torch.dtype,
workspace: torch.Tensor, workspace: torch.Tensor,
...@@ -41,19 +43,21 @@ def fp8_gemm( ...@@ -41,19 +43,21 @@ def fp8_gemm(
out_dtype = tex.DType.kFloat32 if fp32_output else TE_DType[out_dtype] out_dtype = tex.DType.kFloat32 if fp32_output else TE_DType[out_dtype]
tex.te_gemm( _ = torch.ops.tex_ts.te_gemm_ts(
A, A,
A_scale_inv, A_scale_inv,
A_fp8_tensor,
A_dtype, A_dtype,
True, # transa True, # transa
B, B,
B_scale_inv, B_scale_inv,
B_fp8_tensor,
B_dtype, B_dtype,
False, # transb False, # transb
out, out,
out_dtype, out_dtype,
bias if use_bias else empty_tensor, bias if use_bias else empty_tensor,
empty_tensor, empty_tensor, # this is pre_gelu_out
False, # grad False, # grad
workspace, workspace,
workspace.shape[0], workspace.shape[0],
...@@ -87,6 +91,7 @@ def gemm( ...@@ -87,6 +91,7 @@ def gemm(
transa = layout[0] == "T" transa = layout[0] == "T"
transb = layout[1] == "T" transb = layout[1] == "T"
empty_tensor = torch.Tensor() empty_tensor = torch.Tensor()
fp8_index = -1 # dummy index
input_dtype = TE_DType[dtype] input_dtype = TE_DType[dtype]
output_dtype = tex.DType.kFloat32 if fp32_output else input_dtype output_dtype = tex.DType.kFloat32 if fp32_output else input_dtype
...@@ -115,13 +120,15 @@ def gemm( ...@@ -115,13 +120,15 @@ def gemm(
bias = bias if use_bias else empty_tensor bias = bias if use_bias else empty_tensor
tex.te_gemm( _ = torch.ops.tex_ts.te_gemm_ts(
A, A,
empty_tensor, empty_tensor,
fp8_index,
input_dtype, input_dtype,
transa, transa,
B, B,
empty_tensor, empty_tensor,
fp8_index,
input_dtype, input_dtype,
transb, transb,
out, out,
...@@ -214,11 +221,12 @@ def fp8_gelu( ...@@ -214,11 +221,12 @@ def fp8_gelu(
otype: tex.DType, otype: tex.DType,
) -> torch.Tensor: ) -> torch.Tensor:
"""GeLU with FP8 output""" """GeLU with FP8 output"""
return tex.fp8_gelu( return torch.ops.tex_ts.fp8_gelu_ts(
inp, inp,
fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv[fp8_tensor], fp8_meta_tensor.scale_inv,
fp8_tensor,
otype, otype,
) )
...@@ -245,6 +253,48 @@ def layernorm_fwd_fp8( ...@@ -245,6 +253,48 @@ def layernorm_fwd_fp8(
) )
def layernorm_fwd_fp8_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """LayerNorm with FP8 output, specialized for inference.

    Unlike layernorm_fwd_fp8, only the normalized output is returned
    (no mu/rsigma statistics).
    """
    # Dispatch through the TorchScript custom op (torch.ops.tex_ts) rather
    # than the raw extension; the full meta tensors plus the fp8_tensor index
    # are passed so the op can do the per-tensor selection itself.
    return torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts(
        inp,
        weight,
        bias,
        eps,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        otype,
    )
def layernorm_fwd_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """LayerNorm (non-FP8 output), specialized for inference.

    Returns only the normalized output, dispatched through the
    TorchScript custom op.
    """
    return torch.ops.tex_ts.layernorm_fwd_inf_ts(
        inp,
        weight,
        bias,
        eps,
    )
def cast_to_fp8( def cast_to_fp8(
inp: torch.Tensor, inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta, fp8_meta_tensor: tex.FP8TensorMeta,
...@@ -252,11 +302,12 @@ def cast_to_fp8( ...@@ -252,11 +302,12 @@ def cast_to_fp8(
otype: tex.DType, otype: tex.DType,
) -> torch.Tensor: ) -> torch.Tensor:
"""Cast input to FP8""" """Cast input to FP8"""
return tex.cast_to_fp8( return torch.ops.tex_ts.cast_to_fp8_ts(
inp, inp,
fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv[fp8_tensor], fp8_meta_tensor.scale_inv,
fp8_tensor,
otype, otype,
) )
...@@ -269,9 +320,10 @@ def cast_from_fp8( ...@@ -269,9 +320,10 @@ def cast_from_fp8(
otype: tex.DType, otype: tex.DType,
) -> torch.Tensor: ) -> torch.Tensor:
"""Cast input from FP8""" """Cast input from FP8"""
return tex.cast_from_fp8( return torch.ops.tex_ts.cast_from_fp8_ts(
inp, inp,
fp8_meta_tensor.scale_inv[fp8_tensor], fp8_meta_tensor.scale_inv,
fp8_tensor,
itype, itype,
otype, otype,
) )
...@@ -94,6 +94,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { ...@@ -94,6 +94,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
return transformer_engine::DType::kFloat32; return transformer_engine::DType::kFloat32;
case at::kBFloat16: case at::kBFloat16:
return transformer_engine::DType::kBFloat16; return transformer_engine::DType::kBFloat16;
case at::kBool:
return transformer_engine::DType::kByte;
default: default:
NVTE_ERROR("Invalid type"); NVTE_ERROR("Invalid type");
} }
......
...@@ -397,6 +397,23 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input, ...@@ -397,6 +397,23 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
} }
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
                                 const at::Tensor &weight,
                                 const at::Tensor &bias,
                                 float eps,
                                 at::Tensor scale,
                                 at::Tensor amax,
                                 at::Tensor scale_inv,
                                 transformer_engine::DType otype
) {
  // Inference-only wrapper around layernorm_fwd_fp8: drop the mu/rsigma
  // statistics and return just the normalized FP8 output (index 0).
  return layernorm_fwd_fp8(input, weight, bias, eps,
                           scale, amax, scale_inv, otype)[0];
}
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
const at::Tensor &bias, const at::Tensor &bias,
...@@ -428,6 +445,16 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, ...@@ -428,6 +445,16 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
return {ln_out, mu, rsigma}; return {ln_out, mu, rsigma};
} }
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
                             const at::Tensor &weight,
                             const at::Tensor &bias,
                             float eps
) {
  // Inference-only wrapper around layernorm_fwd: drop the mu/rsigma
  // statistics and return just the normalized output (index 0).
  return layernorm_fwd(input, weight, bias, eps)[0];
}
at::Tensor cast_to_fp8(const at::Tensor &input, at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale, const at::Tensor &scale,
......
...@@ -95,6 +95,15 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input, ...@@ -95,6 +95,15 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
transformer_engine::DType otype transformer_engine::DType otype
); );
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
...@@ -102,6 +111,11 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, ...@@ -102,6 +111,11 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
float eps float eps
); );
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps
);
at::Tensor cast_to_fp8(const at::Tensor &input, at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale, const at::Tensor &scale,
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <torch/script.h>
#include "extensions.h"
namespace {
// Map a TorchScript-supplied int64 dtype code back onto the
// transformer_engine::DType enum; NVTE_ERROR on out-of-range codes.
transformer_engine::DType reverse_map_dtype(int64_t dtype) {
  const auto num_types = static_cast<int64_t>(transformer_engine::DType::kNumTypes);
  if (dtype < 0 || dtype >= num_types) {
    NVTE_ERROR("Type not supported.");
  }
  return static_cast<transformer_engine::DType>(dtype);
}
}  // namespace
at::Tensor cast_to_fp8_ts(const at::Tensor &input,
                          const at::Tensor &scale,
                          const at::Tensor &amax,
                          const at::Tensor &scale_inv,
                          int64_t fp8_tensor,
                          int64_t otype) {
  // TorchScript wrapper: select the per-tensor FP8 meta entries by index and
  // forward to the eager cast_to_fp8 extension.
  const transformer_engine::DType otype_arg = reverse_map_dtype(otype);
  at::Tensor casted = cast_to_fp8(input,
                                  scale[fp8_tensor],
                                  amax[0][fp8_tensor],
                                  scale_inv[fp8_tensor],
                                  otype_arg);
  // NOTE(review): clone appears intended to give the op an output that owns
  // its storage — confirm before removing.
  return casted.clone();
}
at::Tensor cast_from_fp8_ts(const at::Tensor &input,
                            const at::Tensor &scale_inv,
                            int64_t fp8_tensor,
                            int64_t itype,
                            int64_t otype) {
  // TorchScript wrapper: convert the int64 dtype codes, pick the scale-inv
  // entry for this tensor, and forward to the eager cast_from_fp8 extension.
  const transformer_engine::DType itype_arg = reverse_map_dtype(itype);
  const transformer_engine::DType otype_arg = reverse_map_dtype(otype);
  at::Tensor decasted = cast_from_fp8(input,
                                      scale_inv[fp8_tensor],
                                      itype_arg,
                                      otype_arg);
  // NOTE(review): clone appears intended to give the op an output that owns
  // its storage — confirm before removing.
  return decasted.clone();
}
at::Tensor fp8_gelu_ts(at::Tensor input,
                       at::Tensor scale,
                       at::Tensor amax,
                       at::Tensor scale_inv,
                       int64_t fp8_tensor,
                       int64_t otype) {
  // TorchScript wrapper: select the per-tensor FP8 meta entries by index and
  // forward to the eager fp8_gelu extension.
  const transformer_engine::DType otype_arg = reverse_map_dtype(otype);
  at::Tensor activated = fp8_gelu(input,
                                  scale[fp8_tensor],
                                  amax[0][fp8_tensor],
                                  scale_inv[fp8_tensor],
                                  otype_arg);
  // NOTE(review): clone appears intended to give the op an output that owns
  // its storage — confirm before removing.
  return activated.clone();
}
at::Tensor te_gemm_ts(at::Tensor A,
                      at::Tensor A_scale_inverse,
                      int64_t A_fp8_tensor,
                      int64_t A_type,
                      int64_t transa,
                      at::Tensor B,
                      at::Tensor B_scale_inverse,
                      int64_t B_fp8_tensor,
                      int64_t B_type,
                      int64_t transb,
                      at::Tensor D,
                      int64_t D_type,
                      at::Tensor bias,
                      at::Tensor pre_gelu_out,
                      int64_t grad,
                      at::Tensor workspace,
                      int64_t workspaceSize,
                      int64_t accumulate,
                      int64_t use_split_accumulator) {
  // TorchScript wrapper for te_gemm. TorchScript passes enum/bool/size_t
  // parameters as int64_t, so cast them back to the types te_gemm expects.
  transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
  bool transa_arg = static_cast<bool>(transa);
  transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
  bool transb_arg = static_cast<bool>(transb);
  transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
  bool grad_arg = static_cast<bool>(grad);
  size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
  bool accumulate_arg = static_cast<bool>(accumulate);
  bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);

  // When FP8 scale-inverse metadata is present (non-empty tensor), select the
  // entry for this tensor; otherwise pass the empty tensor through unchanged.
  // (Previously the tensors were unconditionally clone()d first, but the
  // clone was discarded on the FP8 path and unnecessary on the empty path.)
  at::Tensor A_scale_inverse_arg = A_scale_inverse;
  if (A_scale_inverse.numel())
    A_scale_inverse_arg = A_scale_inverse[A_fp8_tensor];

  at::Tensor B_scale_inverse_arg = B_scale_inverse;
  if (B_scale_inverse.numel())
    B_scale_inverse_arg = B_scale_inverse[B_fp8_tensor];

  te_gemm(A,
          A_scale_inverse_arg,
          A_type_arg,
          transa_arg,
          B,
          B_scale_inverse_arg,
          B_type_arg,
          transb_arg,
          D,
          D_type_arg,
          bias,
          pre_gelu_out,
          grad_arg,
          workspace,
          workspaceSize_arg,
          accumulate_arg,
          use_split_accumulator_arg);

  // te_gemm writes the result into D in place; return it so the registered
  // op has a Tensor output.
  return D;
}
at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
                                    const at::Tensor &weight,
                                    const at::Tensor &bias,
                                    double eps,
                                    at::Tensor scale,
                                    at::Tensor amax,
                                    at::Tensor scale_inv,
                                    int64_t fp8_tensor,
                                    int64_t otype) {
  // TorchScript wrapper: convert the int64 dtype code and the double epsilon
  // (TorchScript has no float32 scalar) before forwarding.
  const transformer_engine::DType otype_arg = reverse_map_dtype(otype);
  const float eps_float = static_cast<float>(eps);
  at::Tensor normalized = layernorm_fwd_fp8_inf(input,
                                                weight,
                                                bias,
                                                eps_float,
                                                scale,
                                                amax,
                                                scale_inv,
                                                otype_arg);
  // NOTE(review): clone appears intended to give the op an output that owns
  // its storage — confirm before removing.
  return normalized.clone();
}
at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
                                const at::Tensor &weight,
                                const at::Tensor &bias,
                                double eps) {
  // TorchScript wrapper: narrow the double epsilon (TorchScript has no
  // float32 scalar) and forward to the inference layernorm.
  const float eps_float = static_cast<float>(eps);
  at::Tensor normalized = layernorm_fwd_inf(input, weight, bias, eps_float);
  // NOTE(review): clone appears intended to give the op an output that owns
  // its storage — confirm before removing.
  return normalized.clone();
}
// Register the custom ops in the "tex_ts" namespace so they are callable as
// torch.ops.tex_ts.<name> from Python/TorchScript (and thus visible to the
// ONNX exporter's symbolic functions).
TORCH_LIBRARY(tex_ts, m) {
  m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
  m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
  m.def("fp8_gelu_ts", &fp8_gelu_ts);
  m.def("te_gemm_ts", &te_gemm_ts);
  m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
  m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
}
...@@ -4,12 +4,14 @@ ...@@ -4,12 +4,14 @@
"""Top level Transformer Engine PyTorch modules""" """Top level Transformer Engine PyTorch modules"""
import os import os
import pickle
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Union, Optional, Callable, Tuple, Dict, List, Any, Mapping from typing import Union, Optional, Callable, Tuple, Dict, Any, Mapping
from functools import partial from functools import partial
from contextlib import contextmanager from contextlib import contextmanager
import numpy as np
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.nn import init from torch.nn import init
...@@ -70,6 +72,8 @@ from .cpp_extensions import ( ...@@ -70,6 +72,8 @@ from .cpp_extensions import (
fp8_gelu, fp8_gelu,
fp8_cast_transpose_bgrad_dgelu_fused, fp8_cast_transpose_bgrad_dgelu_fused,
layernorm_fwd_fp8, layernorm_fwd_fp8,
layernorm_fwd_fp8_inf,
layernorm_fwd_inf,
cast_to_fp8, cast_to_fp8,
cast_from_fp8, cast_from_fp8,
) )
...@@ -192,8 +196,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -192,8 +196,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.set_meta_tensor(True) self.set_meta_tensor(True)
self.set_meta_tensor(False) self.set_meta_tensor(False)
def get_extra_state(self) -> Union[List[Any], None]: def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing.""" """Save before checkpointing."""
state = None
if self.fp8: if self.fp8:
state = {} state = {}
state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
...@@ -210,10 +215,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -210,10 +215,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
extra[k] = v extra[k] = v
state["extra_fp8_variables"] = extra state["extra_fp8_variables"] = extra
return state state_serialized = pickle.dumps(state)
return None state_tensor = torch.tensor(np.frombuffer(state_serialized, dtype=np.uint8))
def set_extra_state(self, state: Union[List[Any], None]) -> None: return state_tensor
def set_extra_state(self, state: torch.Tensor) -> None:
"""Load previous state.""" """Load previous state."""
if state is None: if state is None:
return return
...@@ -252,6 +259,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -252,6 +259,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["autocast_id_bwd"] = state[9] self.fp8_meta["autocast_id_bwd"] = state[9]
return return
if isinstance(state, torch.Tensor):
state = pickle.loads(state.detach().numpy().tobytes())
if state is None:
return
# Restore global FP8 buffer states. # Restore global FP8 buffer states.
set_global_fp8_buffer(state["global_fp8_buffer"]) set_global_fp8_buffer(state["global_fp8_buffer"])
set_global_fp8_recompute_buffer(state["global_fp8_recompute_buffer"]) set_global_fp8_recompute_buffer(state["global_fp8_recompute_buffer"])
...@@ -541,13 +553,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -541,13 +553,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
fp8_dtype_backward, fp8_dtype_backward,
) )
else: else:
grad_output_t = None
grad_output_c = cast_to_fp8( grad_output_c = cast_to_fp8(
grad_output_mat, grad_output_mat,
ctx.fp8_meta["scaling_bwd"], ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1, tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward, fp8_dtype_backward,
) )
grad_output_t = None
grad_bias = None grad_bias = None
return grad_output_mat, grad_output_c, grad_output_t, grad_bias return grad_output_mat, grad_output_c, grad_output_t, grad_bias
...@@ -557,6 +569,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -557,6 +569,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Needs override.""" """Needs override."""
class _LayerNormLinear(torch.autograd.Function): class _LayerNormLinear(torch.autograd.Function):
"""LayerNormLinear semi-top level module """LayerNormLinear semi-top level module
Calls custom cuda extensions. Calls custom cuda extensions.
...@@ -584,6 +597,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -584,6 +597,7 @@ class _LayerNormLinear(torch.autograd.Function):
activation_dtype: torch.dtype, activation_dtype: torch.dtype,
parallel_mode: Union[str, None], parallel_mode: Union[str, None],
return_layernorm_output: bool, return_layernorm_output: bool,
is_training: bool
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -604,19 +618,37 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -604,19 +618,37 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output: if not return_layernorm_output:
ln_out, mu, rsigma = layernorm_fwd_fp8( if is_training:
inputmat, ln_out, mu, rsigma = layernorm_fwd_fp8(
ln_weight, inputmat,
ln_bias, ln_weight,
eps, ln_bias,
fp8_meta["scaling_fwd"], eps,
tex.FP8FwdTensors.GEMM1_INPUT, fp8_meta["scaling_fwd"],
fp8_dtype_forward, tex.FP8FwdTensors.GEMM1_INPUT,
) fp8_dtype_forward,
)
else:
mu = rsigma = None
ln_out = layernorm_fwd_fp8_inf(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else: else:
ln_out_return, mu, rsigma = tex.layernorm_fwd( if is_training:
inputmat, ln_weight, ln_bias, eps ln_out_return, mu, rsigma = tex.layernorm_fwd(
) inputmat, ln_weight, ln_bias, eps
)
else:
ln_out_return, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps
), None, None
ln_out = cast_to_fp8( ln_out = cast_to_fp8(
ln_out_return, ln_out_return,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
...@@ -624,7 +656,12 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -624,7 +656,12 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
) )
else: else:
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps) if is_training:
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps
), None, None
ln_out_return = ln_out ln_out_return = ln_out
# Column Parallel Linear # Column Parallel Linear
...@@ -642,21 +679,31 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -642,21 +679,31 @@ class _LayerNormLinear(torch.autograd.Function):
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights: if update_fp8_weights:
fp8_cast_transpose_fused( if is_training:
weight, fp8_cast_transpose_fused(
fp8_meta["scaling_fwd"], weight,
tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_meta["scaling_fwd"],
fp8_dtype_forward, tex.FP8FwdTensors.GEMM1_WEIGHT,
cast_out=weight_fp8, fp8_dtype_forward,
transpose_out=weight_t_fp8, cast_out=weight_fp8,
) transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward)
out = fp8_gemm( out = fp8_gemm(
weight_fp8, weight_fp8,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT], fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
ln_out_total, ln_out_total,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT], fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
activation_dtype, activation_dtype,
get_workspace(), get_workspace(),
...@@ -678,29 +725,30 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -678,29 +725,30 @@ class _LayerNormLinear(torch.autograd.Function):
use_bias=use_bias, use_bias=use_bias,
) )
ctx.save_for_backward( if is_training:
inputmat, ctx.save_for_backward(
ln_weight, inputmat,
mu, ln_weight,
rsigma, mu,
weight, rsigma,
weight_t_fp8, weight,
ln_out, weight_t_fp8,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ln_out,
) fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8 ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.is_first_microbatch = is_first_microbatch ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.return_layernorm_output = return_layernorm_output ctx.return_layernorm_output = return_layernorm_output
# Row Parallel Linear # Row Parallel Linear
if parallel_mode == "row" and sequence_parallel: if parallel_mode == "row" and sequence_parallel:
...@@ -715,6 +763,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -715,6 +763,7 @@ class _LayerNormLinear(torch.autograd.Function):
return out, ln_out_return.view_as(inp) return out, ln_out_return.view_as(inp)
return out return out
@staticmethod @staticmethod
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
...@@ -768,10 +817,12 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -768,10 +817,12 @@ class _LayerNormLinear(torch.autograd.Function):
# DGRAD: Evaluated unconditionally to feed into Linear backward # DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad = fp8_gemm( dgrad = fp8_gemm(
weight_t_fp8, weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT], fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
grad_output_c, grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1], ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward, fp8_dtype_backward,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
...@@ -804,12 +855,12 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -804,12 +855,12 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad = fp8_gemm( wgrad = fp8_gemm(
ln_out_total_t, ln_out_total_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT], fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
grad_output_t, grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[ ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1,
],
fp8_dtype_backward, fp8_dtype_backward,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
...@@ -890,6 +941,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -890,6 +941,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -1051,7 +1103,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1051,7 +1103,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
if self.parallel_mode == "column": if self.parallel_mode == "column":
set_tensor_model_parallel_attributes(self.bias, True, 0, 1) set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
else: else:
self.register_buffer("bias", torch.Tensor(), persistent=False) self.register_buffer("bias", torch.Tensor().type(params_dtype), persistent=False)
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
...@@ -1110,7 +1162,13 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1110,7 +1162,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch) as inp: with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = bias if bias is not None else self.bias bias_tensor = bias if bias is not None else self.bias
out = _LayerNormLinear.apply( if self.training:
fwd_fn = _LayerNormLinear.apply
args = []
else:
fwd_fn = _LayerNormLinear.forward
args = [None]
args += (
inp, inp,
self.layer_norm_weight, self.layer_norm_weight,
self.layer_norm_bias, self.layer_norm_bias,
...@@ -1130,7 +1188,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1130,7 +1188,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.activation_dtype, self.activation_dtype,
self.parallel_mode, self.parallel_mode,
self.return_layernorm_output, self.return_layernorm_output,
self.training,
) )
out = fwd_fn(*args)
if self.return_layernorm_output: if self.return_layernorm_output:
out, ln_out = out out, ln_out = out
...@@ -1146,7 +1206,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1146,7 +1206,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
return out, ln_out return out, ln_out
return out return out
class _Linear(torch.autograd.Function): class _Linear(torch.autograd.Function):
"""Linear semi-top level module """Linear semi-top level module
Calls custom cuda extensions. Calls custom cuda extensions.
...@@ -1170,6 +1229,7 @@ class _Linear(torch.autograd.Function): ...@@ -1170,6 +1229,7 @@ class _Linear(torch.autograd.Function):
tensor_parallel: bool, tensor_parallel: bool,
activation_dtype: torch.dtype, activation_dtype: torch.dtype,
parallel_mode: Union[str, None], parallel_mode: Union[str, None],
is_training: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = weight.shape[-1] in_features = weight.shape[-1]
...@@ -1186,19 +1246,27 @@ class _Linear(torch.autograd.Function): ...@@ -1186,19 +1246,27 @@ class _Linear(torch.autograd.Function):
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not fp8_meta["recipe"].override_linear_precision.wgrad: if not fp8_meta["recipe"].override_linear_precision.wgrad:
inputmat, inputmat_t = fp8_cast_transpose_fused( if is_training:
inputmat, inputmat, inputmat_t = fp8_cast_transpose_fused(
fp8_meta["scaling_fwd"], inputmat,
tex.FP8FwdTensors.GEMM1_INPUT, fp8_meta["scaling_fwd"],
fp8_dtype_forward, tex.FP8FwdTensors.GEMM1_INPUT,
) fp8_dtype_forward,
)
else:
inputmat = cast_to_fp8(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else: else:
inputmat = cast_to_fp8( inputmat, inputmat_t = cast_to_fp8(
inputmat, inputmat,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
) ), None
# Column Parallel Linear # Column Parallel Linear
if parallel_mode == "column" and sequence_parallel: if parallel_mode == "column" and sequence_parallel:
...@@ -1215,21 +1283,32 @@ class _Linear(torch.autograd.Function): ...@@ -1215,21 +1283,32 @@ class _Linear(torch.autograd.Function):
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights: if update_fp8_weights:
fp8_cast_transpose_fused( if is_training:
weight, fp8_cast_transpose_fused(
fp8_meta["scaling_fwd"], weight,
tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_meta["scaling_fwd"],
fp8_dtype_forward, tex.FP8FwdTensors.GEMM1_WEIGHT,
cast_out=weight_fp8, fp8_dtype_forward,
transpose_out=weight_t_fp8, cast_out=weight_fp8,
) transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
out = fp8_gemm( out = fp8_gemm(
weight_fp8, weight_fp8,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT], fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
inputmat, inputmat,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT], fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
activation_dtype, activation_dtype,
get_workspace(), get_workspace(),
...@@ -1251,28 +1330,29 @@ class _Linear(torch.autograd.Function): ...@@ -1251,28 +1330,29 @@ class _Linear(torch.autograd.Function):
use_bias=use_bias, use_bias=use_bias,
) )
ctx.save_for_backward( if is_training:
inputmat_no_fp8 ctx.save_for_backward(
if not fp8 or fp8_meta["recipe"].override_linear_precision.wgrad inputmat_no_fp8
else None, if not fp8 or fp8_meta["recipe"].override_linear_precision.wgrad
inputmat_t else None,
if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad inputmat_t
else None, if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
weight, else None,
weight_t_fp8, weight,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, weight_t_fp8,
) fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
ctx.activation_dtype = activation_dtype )
ctx.fp8 = fp8 ctx.activation_dtype = activation_dtype
ctx.fp8_meta = fp8_meta ctx.fp8 = fp8
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.fp8_meta = fp8_meta
ctx.is_first_microbatch = is_first_microbatch ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.use_bias = use_bias ctx.is_first_microbatch = is_first_microbatch
ctx.sequence_parallel = sequence_parallel ctx.use_bias = use_bias
ctx.tensor_parallel = tensor_parallel ctx.sequence_parallel = sequence_parallel
ctx.inp_shape = inp.shape ctx.tensor_parallel = tensor_parallel
ctx.parallel_mode = parallel_mode ctx.inp_shape = inp.shape
ctx.tp_group = tp_group ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
# Row Parallel Linear # Row Parallel Linear
if parallel_mode == "row" and sequence_parallel: if parallel_mode == "row" and sequence_parallel:
...@@ -1283,6 +1363,7 @@ class _Linear(torch.autograd.Function): ...@@ -1283,6 +1363,7 @@ class _Linear(torch.autograd.Function):
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1]) return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@staticmethod @staticmethod
def backward( def backward(
ctx, grad_output: torch.Tensor ctx, grad_output: torch.Tensor
...@@ -1339,10 +1420,12 @@ class _Linear(torch.autograd.Function): ...@@ -1339,10 +1420,12 @@ class _Linear(torch.autograd.Function):
# DGRAD # DGRAD
dgrad = fp8_gemm( dgrad = fp8_gemm(
weight_t_fp8, weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT], fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
grad_output_c, grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1], ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward, fp8_dtype_backward,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
...@@ -1374,12 +1457,12 @@ class _Linear(torch.autograd.Function): ...@@ -1374,12 +1457,12 @@ class _Linear(torch.autograd.Function):
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
wgrad = fp8_gemm( wgrad = fp8_gemm(
inputmat_t_total, inputmat_t_total,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT], fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
grad_output_t, grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[ ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1,
],
fp8_dtype_backward, fp8_dtype_backward,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
...@@ -1438,6 +1521,7 @@ class _Linear(torch.autograd.Function): ...@@ -1438,6 +1521,7 @@ class _Linear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -1575,7 +1659,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1575,7 +1659,7 @@ class Linear(TransformerEngineBaseModule):
if self.parallel_mode == "column": if self.parallel_mode == "column":
set_tensor_model_parallel_attributes(self.bias, True, 0, 1) set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
else: else:
self.register_buffer("bias", torch.Tensor(), persistent=False) self.register_buffer("bias", torch.Tensor().type(params_dtype), persistent=False)
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
...@@ -1629,7 +1713,13 @@ class Linear(TransformerEngineBaseModule): ...@@ -1629,7 +1713,13 @@ class Linear(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch) as inp: with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = bias if bias is not None else self.bias bias_tensor = bias if bias is not None else self.bias
out = _Linear.apply( if self.training:
linear_fn = _Linear.apply
args = []
else:
linear_fn = _Linear.forward
args = [None]
args += (
weight if weight is not None else self.weight, weight if weight is not None else self.weight,
self.weight1_fp8 if self.fp8 else None, self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None, self.weight1_t_fp8 if self.fp8 else None,
...@@ -1645,7 +1735,9 @@ class Linear(TransformerEngineBaseModule): ...@@ -1645,7 +1735,9 @@ class Linear(TransformerEngineBaseModule):
self.tp_size > 1, self.tp_size > 1,
self.activation_dtype, self.activation_dtype,
self.parallel_mode, self.parallel_mode,
self.training,
) )
out = linear_fn(*args)
if self.gemm_bias_unfused_add: if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype) out = out + cast_if_needed(bias_tensor, self.activation_dtype)
...@@ -1687,6 +1779,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1687,6 +1779,7 @@ class _LayerNormMLP(torch.autograd.Function):
return_layernorm_output: bool, return_layernorm_output: bool,
bias_gelu_nvfusion: bool, bias_gelu_nvfusion: bool,
set_parallel_mode: bool, set_parallel_mode: bool,
is_training: bool
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -1706,15 +1799,26 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1706,15 +1799,26 @@ class _LayerNormMLP(torch.autograd.Function):
if fp8: if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output: if not return_layernorm_output:
ln_out, mu, rsigma = layernorm_fwd_fp8( if is_training:
inputmat, ln_out, mu, rsigma = layernorm_fwd_fp8(
ln_weight, inputmat,
ln_bias, ln_weight,
eps, ln_bias,
fp8_meta["scaling_fwd"], eps,
tex.FP8FwdTensors.GEMM1_INPUT, fp8_meta["scaling_fwd"],
fp8_dtype_forward, tex.FP8FwdTensors.GEMM1_INPUT,
) fp8_dtype_forward,
)
else:
ln_out = layernorm_fwd_fp8_inf(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else: else:
ln_out_return, mu, rsigma = tex.layernorm_fwd( ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps inputmat, ln_weight, ln_bias, eps
...@@ -1726,9 +1830,14 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1726,9 +1830,14 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
) )
else: else:
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps) if is_training:
ln_out_return = ln_out ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps
), None, None
ln_out_return = ln_out
# Column Parallel Linear # Column Parallel Linear
if set_parallel_mode and sequence_parallel: if set_parallel_mode and sequence_parallel:
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
...@@ -1745,30 +1854,48 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1745,30 +1854,48 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_bias else fc2_bias fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_bias else fc2_bias
if update_fp8_weights: if update_fp8_weights:
fp8_cast_transpose_fused( if is_training:
fc1_weight, fp8_cast_transpose_fused(
fp8_meta["scaling_fwd"], fc1_weight,
tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_meta["scaling_fwd"],
fp8_dtype_forward, tex.FP8FwdTensors.GEMM1_WEIGHT,
cast_out=fc1_weight_fp8, fp8_dtype_forward,
transpose_out=fc1_weight_t_fp8, cast_out=fc1_weight_fp8,
) transpose_out=fc1_weight_t_fp8,
)
fp8_cast_transpose_fused( fp8_cast_transpose_fused(
fc2_weight, fc2_weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT, tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
cast_out=fc2_weight_fp8, cast_out=fc2_weight_fp8,
transpose_out=fc2_weight_t_fp8, transpose_out=fc2_weight_t_fp8,
) )
else:
fc1_weight_t_fp8 = None
fc1_weight_fp8 = cast_to_fp8(
fc1_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
fc2_weight_t_fp8 = None
fc2_weight_fp8 = cast_to_fp8(
fc2_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
)
fc1_out = fp8_gemm( fc1_out = fp8_gemm(
fc1_weight_fp8, fc1_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT], fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
ln_out_total, ln_out_total,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT], fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
activation_dtype, activation_dtype,
get_workspace(), get_workspace(),
...@@ -1786,10 +1913,12 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1786,10 +1913,12 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_out = fp8_gemm( fc2_out = fp8_gemm(
fc2_weight_fp8, fc2_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT], fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
gelu_out, gelu_out,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM2_INPUT], fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
activation_dtype, activation_dtype,
get_workspace(), get_workspace(),
...@@ -1816,7 +1945,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1816,7 +1945,7 @@ class _LayerNormMLP(torch.autograd.Function):
gelu=not bias_gelu_nvfusion, gelu=not bias_gelu_nvfusion,
) )
if bias_gelu_nvfusion: if bias_gelu_nvfusion and is_training:
fc1_out, _, _ = fc1_outputs fc1_out, _, _ = fc1_outputs
gelu_out = bias_gelu_fused(fc1_out, fc1_bias) gelu_out = bias_gelu_fused(fc1_out, fc1_bias)
else: else:
...@@ -1830,35 +1959,35 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1830,35 +1959,35 @@ class _LayerNormMLP(torch.autograd.Function):
bias=fc2_bias, bias=fc2_bias,
use_bias=use_bias, use_bias=use_bias,
) )
if is_training:
ctx.save_for_backward( ctx.save_for_backward(
inputmat, inputmat,
ln_weight, ln_weight,
mu, mu,
rsigma, rsigma,
ln_out, ln_out,
fc1_out, fc1_out,
gelu_out, gelu_out,
fc1_weight, fc1_weight,
fc1_weight_t_fp8, fc1_weight_t_fp8,
fc2_weight, fc2_weight,
fc2_weight_t_fp8, fc2_weight_t_fp8,
fc1_bias, fc1_bias,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
) )
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8 ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.is_first_microbatch = is_first_microbatch ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape ctx.inp_shape = inp.shape
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.bias_gelu_nvfusion = bias_gelu_nvfusion ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
ctx.return_layernorm_output = return_layernorm_output ctx.return_layernorm_output = return_layernorm_output
ctx.set_parallel_mode = set_parallel_mode ctx.set_parallel_mode = set_parallel_mode
# Row Parallel Linear # Row Parallel Linear
if set_parallel_mode and sequence_parallel: if set_parallel_mode and sequence_parallel:
...@@ -1873,6 +2002,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1873,6 +2002,7 @@ class _LayerNormMLP(torch.autograd.Function):
return fc2_out, ln_out_return.view_as(inp) return fc2_out, ln_out_return.view_as(inp)
return fc2_out return fc2_out
@staticmethod @staticmethod
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
...@@ -1931,10 +2061,12 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1931,10 +2061,12 @@ class _LayerNormMLP(torch.autograd.Function):
# FC2 DGRAD; Unconditional # FC2 DGRAD; Unconditional
fc2_dgrad = fp8_gemm( fc2_dgrad = fp8_gemm(
fc2_weight_t_fp8, fc2_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_WEIGHT], fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
grad_output_c, grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1], ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward, fp8_dtype_backward,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
...@@ -1947,12 +2079,12 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1947,12 +2079,12 @@ class _LayerNormMLP(torch.autograd.Function):
gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward) gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
fc2_wgrad = fp8_gemm( fc2_wgrad = fp8_gemm(
gelu_out_t, gelu_out_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_INPUT], fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
grad_output_t, grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[ ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1,
],
fp8_dtype_backward, fp8_dtype_backward,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
...@@ -2010,10 +2142,12 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -2010,10 +2142,12 @@ class _LayerNormMLP(torch.autograd.Function):
# FC1 DGRAD: Unconditional # FC1 DGRAD: Unconditional
fc1_dgrad = fp8_gemm( fc1_dgrad = fp8_gemm(
fc1_weight_t_fp8, fc1_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT], fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
dgelu, dgelu,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT2], ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward, fp8_dtype_backward,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
...@@ -2078,12 +2212,12 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -2078,12 +2212,12 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
fc1_wgrad = fp8_gemm( fc1_wgrad = fp8_gemm(
ln_out_total_t, ln_out_total_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT], fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
dgelu_t, dgelu_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[ ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT2 tex.FP8BwdTensors.GRAD_OUTPUT2,
],
fp8_dtype_backward, fp8_dtype_backward,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
...@@ -2178,6 +2312,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -2178,6 +2312,7 @@ class _LayerNormMLP(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -2372,7 +2507,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2372,7 +2507,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
) )
) )
else: else:
self.register_buffer("fc2_bias", torch.Tensor(), persistent=False) self.register_buffer("fc2_bias", torch.Tensor().type(params_dtype), persistent=False)
# For RPL, bias has to be added after TP collectives # For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM # So it cannot be fused with the GEMM
...@@ -2423,7 +2558,13 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2423,7 +2558,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
""" """
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
out = _LayerNormMLP.apply( if self.training:
fwd_fn = _LayerNormMLP.apply
args = []
else:
fwd_fn = _LayerNormMLP.forward
args = [None]
args += (
inp, inp,
self.layer_norm_weight, self.layer_norm_weight,
self.layer_norm_bias, self.layer_norm_bias,
...@@ -2448,7 +2589,9 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2448,7 +2589,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.return_layernorm_output, self.return_layernorm_output,
self.bias_gelu_nvfusion, self.bias_gelu_nvfusion,
self.set_parallel_mode, self.set_parallel_mode,
self.training,
) )
out = fwd_fn(*args)
if self.return_layernorm_output: if self.return_layernorm_output:
out, ln_out = out out, ln_out = out
......
...@@ -5,9 +5,10 @@ ...@@ -5,9 +5,10 @@
"""Fused scaled masked softmax functions""" """Fused scaled masked softmax functions"""
import os import os
from typing import Callable, Tuple, Union from typing import Callable, Tuple, Union
import torch import torch
from torch import nn from torch import nn
import torch._C._onnx as _C_onnx
from torch.onnx import _type_utils
import transformer_engine_extensions as tex import transformer_engine_extensions as tex
...@@ -46,6 +47,36 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): ...@@ -46,6 +47,36 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
return input_grads, None return input_grads, None
@staticmethod
def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
"""ScaledUpperTriangMaskedSoftmax symbolic method"""
def triangular_mask():
dtype = _type_utils.JitScalarType.INT64
ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype)
k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
mask = g.op("Trilu", ones, k, upper_i=1)
mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
return mask
# Captures the logic of function scaled_upper_triang_masked_softmax_warp_forward
if inputs.type().scalarType() == "BFloat16":
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
mask = triangular_mask()
one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
inv_mask = g.op("Sub", one, mask)
neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
softmax_mask = g.op("Mul", mask, neg_tenK)
scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
scaled = g.op("Mul", inputs, scale_input)
masked_scaled = g.op("Mul", inv_mask, scaled)
masked = g.op("Add", masked_scaled, softmax_mask)
out = g.op("Softmax", masked)
if inputs.type().scalarType() == "BFloat16":
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return out
class ScaledMaskedSoftmax(torch.autograd.Function): class ScaledMaskedSoftmax(torch.autograd.Function):
""" """
...@@ -78,6 +109,35 @@ class ScaledMaskedSoftmax(torch.autograd.Function): ...@@ -78,6 +109,35 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
) )
return input_grads, None, None return input_grads, None, None
@staticmethod
def symbolic(
g: torch.Graph,
inputs: torch._C.Value,
mask: torch._C.Value,
scale: float) -> torch._C.Value:
"""ScaledMaskedSoftmax symbolic method"""
# Captures the logic of function scaled_masked_softmax_warp_forward.
# output = softmax(mask(input*scale)
# Computed as:
# masked_scaled = (1 - mask)*(input*scale)
# softmax_mask = mask * -10000
# output = softmax(masked_scaled + softmax_mask)
if inputs.type().scalarType() == "BFloat16":
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
scaled = g.op("Mul", inputs, scale_input)
one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
inv_mask = g.op("Sub", one, mask)
# Note: type is hard coded because softmax uses FP16 or BF16
neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
softmax_mask = g.op("Mul", mask, neg_tenK)
masked_scaled = g.op("Mul", inv_mask, scaled)
masked = g.op("Add", masked_scaled, softmax_mask)
out = g.op("Softmax", masked)
if inputs.type().scalarType() == "BFloat16":
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return out
class ScaledSoftmax(torch.autograd.Function): class ScaledSoftmax(torch.autograd.Function):
""" """
...@@ -107,6 +167,19 @@ class ScaledSoftmax(torch.autograd.Function): ...@@ -107,6 +167,19 @@ class ScaledSoftmax(torch.autograd.Function):
) )
return input_grads, None, None return input_grads, None, None
@staticmethod
def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
"""ScaledSoftmax symbolic method"""
if inputs.type().scalarType() == "BFloat16":
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
scaled = g.op("Mul", inputs, scale_input)
out = g.op("Softmax", scaled)
if inputs.type().scalarType() == "BFloat16":
out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
return out
class FusedScaleMaskSoftmax(nn.Module): class FusedScaleMaskSoftmax(nn.Module):
""" """
...@@ -163,7 +236,7 @@ class FusedScaleMaskSoftmax(nn.Module): ...@@ -163,7 +236,7 @@ class FusedScaleMaskSoftmax(nn.Module):
and attn_batches % 4 == 0 # np * b must be divisor of 4 and attn_batches % 4 == 0 # np * b must be divisor of 4
): ):
if 0 <= sk <= 4096: if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(sk) batch_per_block = self.get_batch_per_block(int(sk))
if self.attn_mask_type == "causal": if self.attn_mask_type == "causal":
if attn_batches % batch_per_block == 0: if attn_batches % batch_per_block == 0:
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
ONNX symbolic functions for Transformer Engine
Warnings of the type pasted below are a known Pytorch issue
(https://github.com/pytorch/pytorch/issues/81693):
tests/test_onnx_export.py::test_export_cast_ops[112]
/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py:649:
UserWarning: The shape inference of trt::TRT_FP8DequantizeLinear type is missing,
so it may result in wrong shape inference for the exported graph.
Please consider adding it in symbolic function. (Triggered internally at
/opt/pytorch/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1880.)
_C._jit_pass_onnx_graph_shape_type_inference(
Scale tensors are treated as lists ("fs") instead of tensors ("v") because we need to access
specific entries using the index passed as `fp8_tensor`. If you fail to do this you will get
the following error when accessing a specific scale element (e.g. `scale_inv[fp8_tensor]`):
TypeError: 'torch._C.Value' object is not subscriptable
"""
import torch
from torch.onnx import symbolic_helper, register_custom_op_symbolic
import torch._C._onnx as _C_onnx
import transformer_engine_extensions as tex
# This file registers custom op symbolic ONNX functions and does not export any symbols.
__all__ = []
# Custom ops spec version
VER = 1
UNSPECIFIED_TYPE = -1
def make_op_name(op_name: str) -> str:
    """Return *op_name* qualified with the TensorRT custom-op domain prefix."""
    return f"trt::{op_name}"
def quantize(g, inputs, scale_inv, fp8_tensor):
    """Emit a TRT_FP8QuantizeLinear node quantizing `inputs` to FP8 (uint8).

    `scale_inv` is a Python list of scales; `fp8_tensor` selects the entry
    for this tensor.
    """
    out_sizes = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
    # Q inputs are currently constrained to FP32 due to a similar limitation in ORT
    # custom ops, so cast the input if needed.
    if inputs.type().scalarType() == "Half":
        inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
    quantized = g.op(make_op_name("TRT_FP8QuantizeLinear"), inputs, scale)
    return quantized.setType(inputs.type().with_dtype(torch.uint8).with_sizes(out_sizes))
def dequantize(g, inputs, scale_inv, fp8_tensor, otype):
    """Emit a TRT_FP8DequantizeLinear node turning FP8 `inputs` into FP32.

    `otype` may request an FP16 result, in which case a trailing Cast is added.
    """
    out_sizes = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
    scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
    dq = g.op(make_op_name("TRT_FP8DequantizeLinear"), inputs, scale)
    out = dq.setType(inputs.type().with_dtype(torch.float32).with_sizes(out_sizes))
    # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
    # custom ops, so cast the output if needed.
    if otype == int(tex.DType.kFloat16):
        out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
    return out
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
    """ONNX graph for cast_to_fp8: a plain FP8 quantize.

    `scale`, `amax` and `otype` feed the CUDA kernel only and are unused here.
    """
    # pylint: disable=unused-argument
    q_out = quantize(g, inputs, scale_inv, fp8_tensor)
    return q_out
@symbolic_helper.parse_args("v", "fs", "i", "i", "i")
def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
    """ONNX graph for cast_from_fp8: a plain FP8 dequantize.

    `itype` feeds the CUDA kernel only and is unused here.
    """
    # pylint: disable=unused-argument
    dq_out = dequantize(g, inputs, scale_inv, fp8_tensor, otype)
    return dq_out
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
    """ONNX graph for fp8_gelu: tanh-approximation GeLU followed by FP8 quantize.

    `scale`, `amax` and `otype` feed the CUDA kernel only and are unused here.
    """
    # pylint: disable=unused-argument
    gelu_out = torch.onnx.symbolic_opset9.gelu(g, inputs, "tanh")
    return quantize(g, gelu_out, scale_inv, fp8_tensor)
@symbolic_helper.parse_args("v", "fs", "i", "i", "i",
                            "v", "fs", "i", "i", "i",
                            "v", "i", "v", "v", "i",
                            "v", "i", "i", "i")
def onnx_te_gemm(
    g,
    weight,
    weight_scale_inverse,
    weight_fp8_tensor,
    weight_type,
    trans_weight,
    inputs,
    input_scale_inverse,
    input_fp8_tensor,
    input_type,
    trans_input,
    out,
    out_type,
    bias,
    pre_gelu_out,
    grad,
    workspace,
    workspaceSize,
    accumulate,
    use_split_accumulator):
    """ONNX graph for te_gemm.

    FP8 operands are dequantized first, then a standard Gemm is emitted;
    bias add and GeLU are appended when the corresponding tensors are
    non-empty. The many kernel-side arguments (out, workspace, flags, ...)
    are unused in the exported graph.
    """
    # pylint: disable=unused-argument
    is_fp16 = bias.type().scalarType() == "Half"
    if input_type == int(tex.DType.kFloat8E4M3):
        inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, UNSPECIFIED_TYPE)
    if weight_type == int(tex.DType.kFloat8E4M3):
        weight = dequantize(g, weight, weight_scale_inverse, weight_fp8_tensor, UNSPECIFIED_TYPE)

    output = g.op("Gemm", inputs, weight, transA_i=trans_input, transB_i=trans_weight)

    # An empty tensor ([0]) marks "no bias" / "no gelu output" from the caller.
    empty_tensor_size = [0]
    has_bias = torch.onnx.symbolic_helper._get_tensor_sizes(bias) != empty_tensor_size
    has_gelu = torch.onnx.symbolic_helper._get_tensor_sizes(pre_gelu_out) != empty_tensor_size

    # Gemm is computed in FP32; cast down before any elementwise ops so the
    # graph matches the original branch-by-branch construction exactly.
    if is_fp16:
        output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
    if has_bias:
        output = g.op('Add', output, bias)
        # GeLU is only fused when a bias is present (matches the kernel path).
        if has_gelu:
            output = torch.onnx.symbolic_opset9.gelu(g, output)
    return output
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i")
def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, scale_inv, fp8_tensor, otype):
    """ONNX graph for layernorm_fwd_fp8: LayerNorm followed by FP8 quantize.

    `scale`, `amax` and `otype` feed the CUDA kernel only and are unused here.
    """
    # pylint: disable=unused-argument
    normalized = onnx_layernorm_fwd(g, inputs, weight, bias, eps)
    return quantize(g, normalized, scale_inv, fp8_tensor)
@symbolic_helper.parse_args("v", "v", "v", "f")
def onnx_layernorm_fwd(g, inputs, weight, bias, eps):
    """ONNX graph for layernorm_fwd: LayerNorm over all dims except dim 0."""
    # pylint: disable=unused-argument
    sizes = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
    if sizes is None:
        # Static sizes unavailable: fall back to a placeholder list of the
        # right rank (presumably only its length matters downstream — this
        # matches the original behavior).
        rank = torch.onnx.symbolic_helper._get_tensor_rank(inputs)
        assert rank is not None
        sizes = list(range(rank))
    # Normalization axis = 0, so normalized_shape uses all dims except dim = 0.
    return torch.onnx.symbolic_opset9.layer_norm(
        g,
        inputs,
        sizes[1:],
        weight,
        bias,
        eps,
        False  # cudnn_enable (not relevant)
    )
# Register every TorchScript custom op with its ONNX symbolic function.
# Dict insertion order preserves the original registration order.
_OP_SYMBOLICS = {
    'tex_ts::cast_to_fp8_ts': onnx_cast_to_fp8,
    'tex_ts::cast_from_fp8_ts': onnx_cast_from_fp8,
    'tex_ts::fp8_gelu_ts': onnx_fp8_gelu,
    'tex_ts::te_gemm_ts': onnx_te_gemm,
    'tex_ts::layernorm_fwd_fp8_inf_ts': onnx_layernorm_fwd_fp8,
    'tex_ts::layernorm_fwd_inf_ts': onnx_layernorm_fwd,
}
for _op_name, _symbolic_fn in _OP_SYMBOLICS.items():
    register_custom_op_symbolic(_op_name, _symbolic_fn, VER)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment