Unverified Commit 68f60b89 authored by Neta Zmora, committed by GitHub

Add env. var. for efficient text-generation in inference (#214)



* Dynamically-generated causal attention mask (for ONNX export)

TE's default causal mask is square, (seq_len, seq_len), and is
dynamically allocated for different sequence sizes. Dynamic
allocation and dictionary lookups are not supported by ONNX,
and the GPT generative phase uses rectangular masks.

This commit forces softmax to use `forward_torch_softmax` and
to dynamically generate an attention mask when exporting to ONNX.
The mask is generated without conditional control-flow, by generating
a (k_seq_len, k_seq_len) mask and slicing it to (q_seq_len, k_seq_len),
as in the sketch below.
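
A minimal sketch of the control-flow-free derivation (illustrative names, not TE's exact code):

```python
import torch

def causal_mask(q_seq_len: int, k_seq_len: int) -> torch.Tensor:
    """Build a (q_seq_len, k_seq_len) causal mask with no data-dependent branches.

    True marks positions to mask out (keys in the query's future).
    """
    # Square upper-triangular mask over the key sequence; diagonal=1 leaves
    # each query's own position unmasked.
    full = torch.triu(torch.ones(k_seq_len, k_seq_len, dtype=torch.bool), diagonal=1)
    # Slicing covers both phases: square when q_seq_len == k_seq_len (context),
    # rectangular when q_seq_len < k_seq_len (generation).
    return full[k_seq_len - q_seq_len:k_seq_len, :k_seq_len]
```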

An alternative implementation is to pre-allocate a mask of shape
(max_seq, max_seq) and slice that. This is more performant at the
expense of memory, but TE has no concept of max_seq.

* Add to test_export_softmax a test for te.softmax.FusedScaleMaskSoftmax.
* Add test_softmax_mask_fn to test that TE's default attention mask and
the new ONNX-compatible mask produce the same behavior.
* Add test_export_gpt_generation to test that the ONNX model can correctly
handle inputs with different shapes and that the attention mask is adjusted
on-the-fly to different sequence lengths.

Misc:
* Add a PRNG seeding fixture for more stability in tests.
* Add dynamic shapes for ONNX input/output tests.
* Allow validate_result to compare ORT output to pre-computed TE outputs.
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Add NVTE_ONNX_KVCACHE_MAX_SEQ_LEN for efficient text-generation in inference

* Introduce an environment variable (NVTE_ONNX_KVCACHE_MAX_SEQ_LEN) to set the maximum sequence length.
In ONNX inference with KV-cache optimizations for GPT text generation, the attention mask shape can be square (context phase) or rectangular (generation phase).
When this variable is set and the model is exported to ONNX, TE preallocates an upper-triangular (k=1) matrix of the size prescribed by the variable and dynamically slices it to the required mask shape (see the sketch below).
TE models can still be exported to ONNX when NVTE_ONNX_KVCACHE_MAX_SEQ_LEN is not configured, but the attention mask is then always square and unsuitable for efficient text generation.
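
A minimal export sketch under this setting (the layer sizes and output file name are illustrative):

```python
import os
import torch
import transformer_engine.pytorch as te

# Must be set before constructing the module: the triu mask buffer is
# preallocated when FusedScaleMaskSoftmax is initialized.
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = "128"

model = te.TransformerLayer(
    64,    # hidden_size
    256,   # ffn_hidden_size
    4,     # num_attention_heads
    self_attn_mask_type="causal",
).to("cuda")

inp = torch.rand(128, 1, 64, device="cuda")  # (seq, batch, hidden)
with torch.inference_mode(), te.onnx_export(True):
    torch.onnx.export(
        model, (inp,), "gpt_layer.onnx",
        input_names=["input"], output_names=["output"],
        dynamic_axes={"input": {0: "seq"}, "output": {0: "seq"}},
    )
```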

* Work around a torch.onnx.export bug that incorrectly folds
layer_norm(data, scale=add(gamma,1)) to layer_norm(data, scale=gamma)
when LayerNorm is used with zero-centered gamma (illustrated below).
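
For reference, a minimal sketch of zero-centered gamma in plain PyTorch (not TE's kernel): the stored parameter is the offset from 1, so folding away the `+1` is only correct when gamma == 0.

```python
import torch
import torch.nn.functional as F

def layernorm_zero_centered_gamma(x: torch.Tensor, gamma: torch.Tensor,
                                  beta: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Zero-centered gamma stores (scale - 1); the effective LN weight is gamma + 1.
    # Folding `gamma + 1` down to `gamma` changes the result for any gamma != 0.
    return F.layer_norm(x, x.shape[-1:], weight=gamma + 1.0, bias=beta, eps=eps)
```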

* ONNX export tests
  * Add a fixture (seed_default_rng) to seed the PRNG
  * Add a fixture (set_max_seq_len) to set the max sequence length when exporting to ONNX for GPT text generation
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Fix linting errors
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Remove immutable default values from a couple of function signatures
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Add @skip_FP8 to test_export_gpt_generation
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Update transformer_engine/pytorch/softmax.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CI error for softmax export
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0d2021ef
......@@ -29,7 +29,7 @@ import numpy as np
import onnxruntime as ort
import torch
from torch import nn as nn
from typing import Union, Tuple
from typing import Union, Tuple, List
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_extensions as tex
......@@ -71,6 +71,18 @@ fp8_available, reason_for_no_fp8 = is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.fixture()
def seed_default_rng():
"""Reseed the PRNG for test reproducibility"""
torch.random.seed()
@pytest.fixture()
def set_max_seq_len(max_seq_len=128):
"""Set the maximum sequence length that can be used for attention masking"""
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}"
def create_fp8_recipe():
return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
......@@ -81,11 +93,14 @@ def do_export(
fname: str,
use_fp8: bool=True,
opset: int=OPSET,
input_names: list=["input"],
output_names: list=["output"],
input_names: List[str]=None,
output_names: List[str]=None,
dynamic_axes: List[str]=None
):
"""Export to ONNX"""
fp8_recipe = create_fp8_recipe()
input_names = input_names or ["input"]
output_names = output_names or ["output"]
with torch.inference_mode(), te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
warnings.filterwarnings(
......@@ -109,13 +124,11 @@ def do_export(
inps,
fname,
verbose=True,
dynamic_axes=dynamic_axes,
opset_version=opset,
input_names=input_names,
output_names=output_names,
# Do not constant-fold because torch.onnx incorrectly folds
# layer_norm(data, scale=add(gamma,1)) to layer_norm(data, scale=gamma)
# when we use LN with zero-centered gamma.
do_constant_folding=False,
do_constant_folding=True,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
......@@ -148,6 +161,32 @@ def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tens
return te_outputs_np
def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname):
""" Compare ORT and TE outputs."""
assert len(onnx_outputs) == len(te_outputs)
# Compare ORT and PyTorch outputs.
for onnx_output, te_output in zip(onnx_outputs, te_outputs):
# np.isclose: abs(a - b) <= (atol + rtol * abs(b))
ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)
mismatches = ac.nonzero()
mismatched_ids = [loc for loc in zip(*mismatches)]
if mismatched_ids:
# Log some information in case of error.
print("*" * 100)
nb_errors = len(mismatched_ids)
nb_vals = min(nb_errors, max_errors_printed)
print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})")
print(f"Showing first {nb_vals} errors (ONNX -- TE):")
abs_err = np.abs(onnx_output - te_output)
errors = abs_err[mismatches]
for loc in mismatched_ids[:nb_vals]:
ref = te_output[loc]
print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}")
print(f"Max error: {np.max(errors)}")
if nb_errors > allow_cnt_errors:
raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")
def validate_result(
fname: str,
inps: Union[Tuple[torch.Tensor], torch.Tensor],
......@@ -157,8 +196,9 @@ def validate_result(
max_errors_printed: int=10,
is_fp8: bool=False,
allow_cnt_errors: int=0,
input_names: list=["input"],
output_names: list=["output"],
input_names: List[str]=None,
output_names: List[str]=None,
te_outputs: List[torch.Tensor]=None,
):
"""Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
representation using ONNX Runtime (ORT) and ensure they are close.
......@@ -171,6 +211,8 @@ def validate_result(
a very small number (0-3) of outliers. This is fine to do because these outliers are due to
small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX
representation (the tests assume both the ORT and TE kernels are correct).
Argument `te_outputs` can be used to provide pre-computed TE outputs.
"""
def create_ort_session(fname: str, is_fp8: bool):
......@@ -220,38 +262,17 @@ def validate_result(
custom_outputs.add([output_data], runner_name="custom_runner")
custom_outputs.save(json_fname)
def compare_outputs(onnx_outputs, te_outputs):
""" Compare ORT and TE outputs."""
assert len(onnx_outputs) == len(te_outputs)
# Compare ORT and PyTorch outputs.
for onnx_output, te_output in zip(onnx_outputs, te_outputs):
# np.isclose: abs(a - b) <= (atol + rtol * abs(b))
ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)
mismatches = ac.nonzero()
mismatched_ids = [loc for loc in zip(*mismatches)]
if mismatched_ids:
# Log some information in case of error.
print("*" * 100)
nb_errors = len(mismatched_ids)
nb_vals = min(nb_errors, max_errors_printed)
print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})")
print(f"Showing first {nb_vals} errors (ONNX -- TE):")
abs_err = np.abs(onnx_output - te_output)
errors = abs_err[mismatches]
for loc in mismatched_ids[:nb_vals]:
ref = te_output[loc]
print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}")
print(f"Max error: {np.max(errors)}")
if nb_errors > allow_cnt_errors:
raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")
input_names = input_names or ["input"]
output_names = output_names or ["output"]
# Run ORT session and TE model.
fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
if not te_outputs:
te_outputs = te_infer(model, inps, is_fp8)
ort_s = create_ort_session(fname, is_fp8)
input_feed = create_ort_input_dict(ort_s, inps)
onnx_outputs = ort_s.run(None, input_feed=input_feed)
compare_outputs(onnx_outputs, te_outputs)
compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname)
serialize_inputs_outputs(fname, inps, input_names, te_outputs, output_names)
......@@ -302,7 +323,7 @@ Tests cases begin here.
[torch.float16, 1e-7],
[torch.bfloat16, 5e-3]
])
def test_export_cast_ops(scale_factor: float, atol: float, precision: torch.dtype):
def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype):
class TestFP8_QDQ(nn.Module):
def __init__(self, fake_bf16_io):
super().__init__()
......@@ -338,6 +359,7 @@ def test_export_cast_ops(scale_factor: float, atol: float, precision: torch.dtyp
high_prec_str = dtype2str(precision)
fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_QDQ(fake_bf16_io)
do_export(model, inp, fname)
validate_result(fname, inp, model, atol=atol, is_fp8=True)
......@@ -408,6 +430,7 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
(torch.bfloat16, True, True, False),
])
def test_export_gemm(
seed_default_rng,
precision, # Precision of inputs, weights, output and bias
use_fp8,
use_bias,
......@@ -533,6 +556,7 @@ def test_export_gemm(
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm(
seed_default_rng,
use_fp8: bool,
scale_factor: float,
precision: torch.dtype,
......@@ -605,26 +629,37 @@ def test_export_layernorm(
@skip_FP8
@pytest.mark.parametrize("softmax_def", [
@pytest.mark.parametrize("softmax_fn", [
softmax_defs.ScaledUpperTriangMaskedSoftmax,
softmax_defs.ScaledMaskedSoftmax,
softmax_defs.ScaledSoftmax,
te.softmax.FusedScaleMaskSoftmax,
])
# Softmax kernel only supports FP16 or BF16!
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
def test_export_softmax(softmax_def, precision):
def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision):
class Test_Softmax(nn.Module):
def __init__(self, softmax_function, mask_inp=False):
def __init__(self, softmax_fn, mask_inp=False):
super().__init__()
self.softmax_fn = softmax_function
self.softmax_fn = softmax_fn
self.scale = 8 # arbitrary value
self.mask_inp = mask_inp
self.fused_scaled_softmax = None
if self.softmax_fn == te.softmax.FusedScaleMaskSoftmax:
self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
attn_mask_type="causal",
mask_func=te.utils.attention_mask_func,
softmax_in_fp32=True,
)
def forward(self, inp, mask):
scale_factor = 8 # arbitrary value
if self.fused_scaled_softmax:
ret = self.fused_scaled_softmax(inp, mask, self.scale)
else:
if self.mask_inp:
ret = self.softmax_fn.apply(inp, mask, scale_factor)
ret = self.softmax_fn.apply(inp, mask, self.scale)
else:
ret = self.softmax_fn.apply(inp, scale_factor)
ret = self.softmax_fn.apply(inp, self.scale)
return ret
# Set dimensions (these are arbitrary).
......@@ -633,19 +668,22 @@ def test_export_softmax(softmax_def, precision):
mask = None
input_names = ["input", "mask"]
inp_shape = [hidden_size, in_features, in_features, in_features]
if softmax_def == softmax_defs.ScaledUpperTriangMaskedSoftmax:
if softmax_fn == softmax_defs.ScaledUpperTriangMaskedSoftmax:
inp_shape = [hidden_size, in_features, in_features]
kernel_str = "ScaledUpperTriangMaskedSoftmax"
model = Test_Softmax(softmax_def)
elif softmax_def == softmax_defs.ScaledMaskedSoftmax:
model = Test_Softmax(softmax_fn)
elif softmax_fn == softmax_defs.ScaledMaskedSoftmax:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(hidden_size, 1, in_features, in_features, device="cuda", dtype=precision)
mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
kernel_str = "ScaledMaskedSoftmax"
model = Test_Softmax(softmax_def, mask_inp=True)
elif softmax_def == softmax_defs.ScaledSoftmax:
model = Test_Softmax(softmax_fn, mask_inp=True)
elif softmax_fn == softmax_defs.ScaledSoftmax:
kernel_str = "ScaledSoftmax"
model = Test_Softmax(softmax_def)
model = Test_Softmax(softmax_fn)
elif softmax_fn == te.softmax.FusedScaleMaskSoftmax:
kernel_str = "TorchSoftmax"
model = Test_Softmax(softmax_fn)
input_tensor = torch.randn(*inp_shape, device="cuda")
input_tensor = input_tensor.to(torch.bfloat16) if precision == torch.bfloat16 else input_tensor.half()
high_prec_str = dtype2str(precision)
......@@ -656,6 +694,59 @@ def test_export_softmax(softmax_def, precision):
validate_result(fname, inp, model, atol=1e-3, input_names=input_names)
# Test dynamically generated softmax mask.
# Softmax kernel only supports FP16 or BF16!
@skip_FP8
@pytest.mark.parametrize("precision", [torch.float16])
def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
class Test_Softmax(nn.Module):
def __init__(self, use_onnx_mask_fn: bool):
super().__init__()
self.scale = 1 # arbitrary value
# Use NVTE_MASKED_SOFTMAX_FUSION to force TE to use forward_torch_softmax
# even when is_in_onnx_export_mode()==False.
os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0"
self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
attn_mask_type="causal",
mask_func=te.utils.attention_mask_func,
softmax_in_fp32=True,
)
def forward(self, inp, mask):
ret = self.fused_scaled_softmax(inp, mask, self.scale)
return ret
# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
mask = None
inp_shape = [hidden_size, in_features, in_features, in_features]
input_tensor = torch.randn(*inp_shape, device="cuda")
input_tensor = input_tensor.to(torch.bfloat16) if precision == torch.bfloat16 else input_tensor.half()
inp = (input_tensor, mask)
high_prec_str = dtype2str(precision)
# Compare the outputs of TE when using the default softmax mask
# to the TE outputs produced when using the ONNX-compatible causal mask.
model = Test_Softmax(use_onnx_mask_fn=False)
te_outputs_default_mask = te_infer(model, inp, is_fp8=True)
with te.onnx_export(True):
# ONNX export mode forces use of the ONNX-compatible causal mask.
model_onnx_mask = Test_Softmax(use_onnx_mask_fn=True)
te_outputs_onnx_mask = te_infer(model_onnx_mask, inp, is_fp8=True)
compare_outputs(te_outputs_default_mask, te_outputs_onnx_mask,
atol=0, rtol=0, max_errors_printed=10, allow_cnt_errors=0, fname="softmax masking")
# Compare the outputs of TE when using the default softmax mask
# to the ORT ONNX outputs produced when using the ONNX-compatible causal mask.
input_names = ["input", "mask"]
kernel_str = "FusedScaleMaskSoftmax"
fname = f"{kernel_str}{high_prec_str}.onnx"
do_export(model, inp, fname, input_names=input_names)
if precision != torch.bfloat16:
validate_result(fname, inp, model_onnx_mask, atol=1e-3, input_names=input_names, te_outputs=te_outputs_default_mask)
@pytest.mark.parametrize("scale_factor", [1])
@pytest.mark.parametrize("use_fp8", [False, True])
# Returning the bias is a TE fusion optimization we don't care about.
......@@ -672,6 +763,7 @@ def test_export_softmax(softmax_def, precision):
# (torch.bfloat16, True),
])
def test_export_linear(
seed_default_rng,
scale_factor: float,
use_fp8: bool,
use_bias: bool,
......@@ -747,6 +839,7 @@ def test_export_linear(
])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm_linear(
seed_default_rng,
scale_factor: float,
use_fp8: bool,
use_bias: bool,
......@@ -803,6 +896,7 @@ def test_export_layernorm_linear(
])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm_mlp(
seed_default_rng,
scale_factor: float,
use_fp8: bool,
use_bias: bool,
......@@ -854,14 +948,15 @@ def test_export_layernorm_mlp(
(torch.float16, False, "padding"), # calls ScaledSoftmax
])
def test_export_core_attention(
seed_default_rng,
set_max_seq_len,
precision: torch.dtype,
use_mask: bool,
attn_mask_type: str,
):
# Set dimensions (these are arbitrary).
kv_channels = 64
num_attention_heads = 1
qkv_size = (2048, 4, num_attention_heads, kv_channels)
seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels)
query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
......@@ -919,6 +1014,8 @@ test_configs_attention_type = [
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type)
def test_export_multihead_attention(
seed_default_rng,
set_max_seq_len,
use_fp8: bool,
use_mask: bool,
attn_mask_type: str,
......@@ -949,8 +1046,8 @@ def test_export_multihead_attention(
init_method,
output_layer_init_method,
)
hidden_states = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
hidden_states_context = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
......@@ -961,9 +1058,6 @@ def test_export_multihead_attention(
if attention_type == "cross":
encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
inp = (hidden_states, attention_mask, encoder_output)
input_names = ["hidden_states", "attention_mask", "encoder_output"]
output_names=["output", "output_1"]
fp8_str = "_fp8" if use_fp8 else ""
dtype_str = dtype2str(precision)
......@@ -982,16 +1076,35 @@ def test_export_multihead_attention(
attention_type=attention_type,
fuse_qkv_params=fuse_qkv_params,
).to(device='cuda')
do_export(model, inp, fname, use_fp8, input_names=input_names, output_names=output_names)
inp_context = (hidden_states_context, attention_mask, encoder_output)
input_names = ["hidden_states", "attention_mask", "encoder_output"]
output_names=["attention_output", "attention_bias"]
do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names,
dynamic_axes={"hidden_states": {0: "seq", 1:"bs"},
"attention_output": {0: "seq", 1:"bs"}})
if not use_fp8:
validate_result(fname, inp_context, model, atol=1e-3, input_names=input_names, output_names=output_names)
else:
validate_result(fname, inp_context, model, atol=1e-2, is_fp8=use_fp8,
input_names=input_names, output_names=output_names, allow_cnt_errors=3)
# In the GPT generative phase (inference) the input sequence is shorter than the maximum
# allowed sequence length, and we want to test this condition.
# Pretend that we're in the generative phase when it makes sense (causal mask and self-attention).
is_generative_phase = (attn_mask_type == "causal" and attention_type == "self")
if is_generative_phase:
seq_len_offset = 8
hidden_states_generative = torch.randn(sequence_length-seq_len_offset, batch_size, hidden_size, dtype=precision, device="cuda")
inp_generative = (hidden_states_generative, attention_mask, encoder_output)
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3, input_names=input_names, output_names=output_names)
elif precision == torch.float32:
validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8,
input_names=input_names, output_names=output_names)
validate_result(fname, inp_generative, model, atol=1e-3, input_names=input_names, output_names=output_names)
else:
validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8,
validate_result(fname, inp_generative, model, atol=1e-2, is_fp8=use_fp8,
input_names=input_names, output_names=output_names, allow_cnt_errors=3)
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [
......@@ -1002,6 +1115,8 @@ def test_export_multihead_attention(
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_transformer_layer(
seed_default_rng,
set_max_seq_len,
use_fp8: bool,
use_mask: bool,
attn_mask_type: str,
......@@ -1058,6 +1173,7 @@ def test_export_transformer_layer(
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_gemm_layernorm(
seed_default_rng,
use_fp8: bool,
ln_scale_factor: float,
gemm_scale_factors: Tuple[float, float],
......@@ -1174,6 +1290,71 @@ def test_export_gemm_layernorm(
fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2, input_names=input_names)
@skip_FP8
@pytest.mark.parametrize("use_fp8", [True, False])
@pytest.mark.parametrize("precision", [torch.float16])
@pytest.mark.parametrize("zero_centered_gamma", [True])
def test_export_gpt_generation(
seed_default_rng,
set_max_seq_len,
use_fp8: bool,
precision: torch.dtype,
zero_centered_gamma: bool
):
"""Test that the ONNX model can correctly handle inputs with different shapes and that
the attention mask is adjusted on-the-fly to different sequence lengths.
"""
# Skip FP8 tests on non-Hopper devices
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Layer configuration
hidden_size = 64
sequence_length = 128
batch_size = 1
ffn_hidden_size = 256
num_attention_heads = 4
attention_mask = None
use_mask = True
attn_mask_type = "causal"
fuse_qkv_params = True
output_layernorm = False
fp8_str = "_fp8" if use_fp8 else ""
fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
high_prec_str = dtype2str(precision)
attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
fname = f"te.transformer_layer_generative{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx"
model = te.TransformerLayer(
hidden_size,
ffn_hidden_size,
num_attention_heads,
self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma).to(device='cuda')
# "Context phase": use full input sequence length
input_names = ["input"]
output_names = ["output"]
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
inp = (input_tensor,)
do_export(model, inp, fname, use_fp8,
input_names=input_names, output_names=output_names,
dynamic_axes={"input": {0: "seq", 1:"bs"},
"output": {0: "seq", 1:"bs"}, })
validate_result(fname, inp, model, atol=5e-3, is_fp8=use_fp8, input_names=input_names)
# "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8.
sequence_length = 1 if not use_fp8 else 8
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
inp = (input_tensor, attention_mask)
validate_result(fname, inp, model, atol=5e-3, is_fp8=use_fp8, input_names=input_names)
@pytest.mark.parametrize("enabled", [True, False])
def test_export_ctx_manager(enabled):
assert is_in_onnx_export_mode() == False
......
......@@ -9,8 +9,8 @@ import torch
from torch import nn
import torch._C._onnx as _C_onnx
from torch.onnx import _type_utils
import transformer_engine_extensions as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode
THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128
......@@ -25,6 +25,26 @@ def _get_default_causal_mask(sq: int) -> torch.Tensor:
return _default_causal_mask[sq]
def _get_onnx_export_causal_mask(
seq_q: int, seq_k: int, onnx_causal_mask: torch.Tensor
) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input, for ONNX export.
ONNX does not support dynamic control-flow, and a KV-cache requires non-square
masks (seq_k has length len(context) + len(generative) while seq_q has length 1).
Argument `onnx_causal_mask` is a square triu (k=1) mask that is sliced to the correct
shape for GPT context and generation phases.
In the context phase the derived mask is a square triu of shape (seq_k, seq_k), and in
the generation phase the mask is rectangular with shape (1, seq_k).
"""
assert len(onnx_causal_mask.size()) == 2
assert onnx_causal_mask.size(0) == onnx_causal_mask.size(1)
assert onnx_causal_mask.size(0) >= (seq_k-seq_q) >= 0
derived_mask = onnx_causal_mask[seq_k-seq_q:seq_k, :seq_k]
return derived_mask
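# Worked example (illustrative sizes): with a preallocated triu(k=1) mask of
# shape (8, 8), the context phase with seq_q == seq_k == 4 yields
# onnx_causal_mask[0:4, :4] (a square triu), while the generation phase with
# seq_q == 1, seq_k == 4 yields onnx_causal_mask[3:4, :4] (a single row in
# which all four key positions remain visible to the newest query).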
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
......@@ -214,6 +234,17 @@ class FusedScaleMaskSoftmax(nn.Module):
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
# Users exporting to ONNX can optimize the attention mask for GPT text generation.
self.kvcache_max_seq = int(os.getenv("NVTE_ONNX_KVCACHE_MAX_SEQ_LEN", "-1"))
if self.kvcache_max_seq > 0:
self.register_buffer(
"onnx_causal_mask",
torch.triu(
torch.ones(self.kvcache_max_seq, self.kvcache_max_seq, device="cuda"),
diagonal=1
).bool(),
persistent=False)
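# Note: the preallocated mask is a (kvcache_max_seq, kvcache_max_seq) boolean
# buffer on the GPU; persistent=False keeps it out of the module's state_dict.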
def forward(
self,
inp: torch.Tensor,
......@@ -231,7 +262,7 @@ class FusedScaleMaskSoftmax(nn.Module):
scale is None or self.softmax_in_fp32
), "softmax should be in fp32 when scaled"
if self.is_kernel_available(*inp.size()):
if self.is_kernel_available(*inp.size()) and not is_in_onnx_export_mode():
return self.forward_fused_softmax(inp, mask, scale)
return self.forward_torch_softmax(inp, mask, scale)
......@@ -287,7 +318,12 @@ class FusedScaleMaskSoftmax(nn.Module):
inp = inp * scale
if self.attn_mask_type == "causal":
mask = _get_default_causal_mask(inp.size()[2])
if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
seq_len_q, seq_len_k = inp.size(2), inp.size(3)
assert self.kvcache_max_seq >= seq_len_k
mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask)
else:
mask = _get_default_causal_mask(inp.size(2))
mask_output = self.mask_func(inp, mask) if mask is not None else inp
probs = torch.nn.Softmax(dim=-1)(mask_output)
......
......@@ -234,7 +234,8 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
if zero_centered_gamma:
inputs_dtype = inputs.type().dtype()
one = g.op("Constant", value_t=torch.tensor([1.], dtype=inputs_dtype, device="cuda"))
shape = g.op("Shape", weight)
one = g.op("ConstantOfShape", shape, value_t=torch.tensor([1], dtype=inputs_dtype))
weight = g.op("Add", weight, one)
axis = -len(normalized_shape)
......