Unverified commit 48b31ca9 authored by galagam, committed by GitHub

ONNX export test - BF16 support (#256)



* add bf16 subgraph tests
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* changes:
1. Add normal mode BF16 tests for all subgraphs
2. Add fake BF16 tests for low-level subgraphs
3. Separate IO serialization from validation
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* ONNX export test - BF16 support part 1

te_infer now returns torch.Tensor outputs in order to support BF16, which numpy
does not currently support (a minimal sketch of this handling appears just before the diff below).
Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>

* ONNX export test - BF16 support part 2

- Separate TE infer from serialize
- Fix serialize function to use full path
- Set unique filenames for fake BF16 (to avoid overwriting the standard BF16 files)
- Remove the code that overwrote the fake_bf16_io value
Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>

* Export test: Slight tolerance increase in test_export_gpt_generation

The previous tolerance caused sporadic failures in ~1% of runs.
Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>

* Remove GEMM fake-bf16 export test and patch to enable it
Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Asfiya Baig <asfiyab@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d7704b98
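A minimal, self-contained sketch (an illustration added to this description, not part of the committed diff) of the BF16 handling described above: numpy has no bfloat16 dtype, so TE outputs are kept as torch.Tensor and only upcast to FP32 at the point where they are converted for comparison or serialization. The helper mirrors the to_numpy() function in the test file below.

# Illustrative sketch only; mirrors the to_numpy() helper in the diff below.
import numpy as np
import torch

def to_numpy(tensor):
    if isinstance(tensor, torch.Tensor):
        if tensor.dtype == torch.bfloat16:
            # numpy cannot represent BF16; upcast (losslessly) to FP32 first.
            tensor = tensor.type(torch.float32)
        tensor = tensor.detach().cpu().numpy()
    return tensor

# Example: compare a BF16 "TE output" against an FP32 "ONNX output".
te_output = torch.randn(4, 4, dtype=torch.bfloat16)
onnx_output = te_output.type(torch.float32).numpy()
assert np.allclose(to_numpy(te_output), onnx_output, atol=5e-3)

In the tests below this pairs with the serialization/validation split: each test calls do_export(), then te_infer() (which now returns torch tensors), then serialize_inputs_outputs(), and only runs validate_result() when ONNX Runtime can actually execute the exported graph (true-BF16 IO is skipped because ORT does not support it).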
@@ -29,7 +29,7 @@ import numpy as np
import onnxruntime as ort
import torch
from torch import nn as nn
from typing import Union, Tuple, List
from typing import Optional, Union, Tuple, List
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_extensions as tex
@@ -133,7 +133,11 @@ def do_export(
def to_numpy(tensor):
return tensor.cpu().numpy()
if isinstance(tensor, torch.Tensor):
if tensor.dtype == torch.bfloat16:
tensor = tensor.type(torch.float32)
tensor = tensor.detach().cpu().numpy()
return tensor
def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
@@ -148,17 +152,13 @@ def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool):
"""Transformer Engine forward prpoagtation.
Return results after copying to the CPU and converting to numpy.
"""
"""Transformer Engine forward propagation."""
fp8_recipe = create_fp8_recipe()
with torch.inference_mode(), te.fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
if not isinstance(te_outputs, tuple):
te_outputs = (te_outputs,)
te_outputs_np = [to_numpy(te_output) for te_output in te_outputs]
return te_outputs_np
return te_outputs
def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname):
@@ -167,6 +167,8 @@ def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, al
# Compare ORT and PyTorch outputs.
for onnx_output, te_output in zip(onnx_outputs, te_outputs):
# np.isclose: abs(a - b) <= (atol + rtol * abs(b))
te_output = to_numpy(te_output)
onnx_output = to_numpy(onnx_output)
ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol)
mismatches = ac.nonzero()
mismatched_ids = [loc for loc in zip(*mismatches)]
@@ -186,6 +188,33 @@ def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, al
if nb_errors > allow_cnt_errors:
raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")
def serialize_inputs_outputs(
fname: str,
inputs: Union[Tuple[torch.Tensor], torch.Tensor],
te_outputs: List[torch.Tensor],
input_names: Optional[List[str]] = None,
output_names: Optional[List[str]] = None,
):
if not SAVE_TEST_IO:
return
fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
input_names = input_names or ["input"]
output_names = output_names or ["output"]
inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,)
named_inputs = zip(input_names, inputs)
input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}]
json_fname = fname[:-len(".onnx")] + "_inputs.json"
save_json(input_data, json_fname, description="custom input data")
json_fname = fname[:-len(".onnx")] + "_output.json"
named_outputs = zip(output_names, te_outputs)
output_data = {k: v.cpu() for k, v in named_outputs if v is not None}
custom_outputs = RunResults()
custom_outputs.add([output_data], runner_name="custom_runner")
custom_outputs.save(json_fname)
def validate_result(
fname: str,
@@ -240,28 +269,6 @@ def validate_result(
inp_dict = dict(zip(input_names, inps))
return inp_dict
def serialize_inputs_outputs(fname, inputs, inputs_names, te_outputs, output_names):
if not SAVE_TEST_IO:
return
inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,)
named_inputs = zip(inputs_names, inputs)
input_data = [{k: to_numpy(v) for k, v in named_inputs if v is not None}]
json_fname = fname[:-len(".onnx")] + "_inputs.json"
save_json(input_data, json_fname, description="custom input data")
if "bf16" in fname:
return
json_fname = fname[:-len(".onnx")] + "_output.json"
named_outputs = zip(output_names, te_outputs)
output_data = dict()
for out_name, outp in named_outputs:
if outp is not None:
assert out_name not in output_data
output_data[out_name] = outp
custom_outputs = RunResults()
custom_outputs.add([output_data], runner_name="custom_runner")
custom_outputs.save(json_fname)
input_names = input_names or ["input"]
output_names = output_names or ["output"]
@@ -273,7 +280,6 @@ def validate_result(
input_feed = create_ort_input_dict(ort_s, inps)
onnx_outputs = ort_s.run(None, input_feed=input_feed)
compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname)
serialize_inputs_outputs(fname, inps, input_names, te_outputs, output_names)
def create_meta(scale_factor: float, size: int=1):
@@ -284,7 +290,10 @@ def create_meta(scale_factor: float, size: int=1):
return meta
def dtype2str(dtype: torch.dtype):
def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
if fake_bf16_io:
assert dtype == torch.bfloat16
return "_fake_bf16"
return {
torch.float32: "_fp32",
torch.float16: "_fp16",
@@ -321,9 +330,14 @@ Tests cases begin here.
"precision, atol", [
[torch.float32, 1e-7],
[torch.float16, 1e-7],
[torch.bfloat16, 5e-3]
[torch.bfloat16, 5e-3],
["fake-torch.bfloat16", 5e-3],
])
def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision
class TestFP8_QDQ(nn.Module):
def __init__(self, fake_bf16_io):
super().__init__()
@@ -353,15 +367,17 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre
# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
fake_bf16_io = precision == torch.bfloat16
inp = torch.randn(hidden_size, in_features, device="cuda",
dtype=torch.float if fake_bf16_io else precision)
high_prec_str = dtype2str(precision)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_QDQ(fake_bf16_io)
do_export(model, inp, fname)
validate_result(fname, inp, model, atol=atol, is_fp8=True)
te_outputs = te_infer(model, inp, is_fp8=True)
serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16:
validate_result(fname, inp, model, atol=atol, is_fp8=True, te_outputs=te_outputs)
@skip_FP8
@pytest.mark.parametrize("scale_factor", [448])
@@ -369,9 +385,14 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre
"precision, atol", [
[torch.float32, 1e-5],
[torch.float16, 1e-5],
[torch.bfloat16, 5e-3]
[torch.bfloat16, 5e-3],
["fake-torch.bfloat16", 5e-3]
])
def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision
class TestFP8_Gelu(nn.Module):
def __init__(self, fake_bf16_io):
super().__init__()
@@ -400,14 +421,16 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
fake_bf16_io = precision == torch.bfloat16
inp = torch.randn(hidden_size, in_features, device="cuda",
dtype=torch.float if fake_bf16_io else precision)
high_prec_str = dtype2str(precision)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_Gelu(fake_bf16_io)
do_export(model, inp, fname)
validate_result(fname, inp, model, rtol=0, atol=atol, is_fp8=True, allow_cnt_errors=2)
te_outputs = te_infer(model, inp, is_fp8=True)
serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16:
validate_result(fname, inp, model, rtol=0, atol=atol, is_fp8=True, allow_cnt_errors=2, te_outputs=te_outputs)
@pytest.mark.parametrize("scale_factors",
@@ -531,8 +554,8 @@ def test_export_gemm(
out_features = 128
hidden_size = 256
in_features = 64
inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda")
weight = torch.randn(out_features, in_features, dtype=precision, device="cuda")
inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision)
weight = torch.randn(out_features, in_features, device="cuda", dtype=precision)
fp8_str = "_fp8" if use_fp8 else ""
bias_str = "_bias" if use_bias else ""
gelu_str = "_gelu" if use_gelu else ""
@@ -542,26 +565,46 @@ def test_export_gemm(
if use_fp8:
model = TestFP8_GEMM(precision, use_bias, use_gelu, scale_factors)
do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
if precision == torch.bfloat16:
return
validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2, is_fp8=True, input_names=input_names)
te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
if precision != torch.bfloat16:
validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2,
is_fp8=True, input_names=input_names, te_outputs=te_outputs)
else:
model = Test_GEMM(precision, use_bias, use_gelu)
do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2, input_names=input_names)
te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
if precision != torch.bfloat16:
validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2,
input_names=input_names, te_outputs=te_outputs)
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize(
"use_fp8, precision, atol", [
[False, torch.float32, 1e-7],
[False, torch.float16, 1e-7],
[False, torch.bfloat16, 1e-7],
[False, "fake-torch.bfloat16", 1e-7],
[True, torch.float32, 1e-7],
[True, torch.float16, 1e-7],
[True, torch.bfloat16, 1e-2],
[True, "fake-torch.bfloat16", 1e-2]
])
def test_export_layernorm(
seed_default_rng,
use_fp8: bool,
scale_factor: float,
precision: torch.dtype,
zero_centered_gamma: bool
zero_centered_gamma: bool,
atol: float
):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
@@ -573,8 +616,10 @@ def test_export_layernorm(
def __init__(self) -> None:
super().__init__()
normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
self.weight = torch.randn(*normalized_shape, device="cuda",
dtype=torch.float if fake_bf16_io else precision)
self.bias = torch.zeros(*normalized_shape, device="cuda",
dtype=torch.float if fake_bf16_io else precision)
self.eps = 1e-6 # An arbitrary small value
def forward(self, inp):
@@ -590,8 +635,10 @@ def test_export_layernorm(
def __init__(self) -> None:
super().__init__()
normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
self.weight = torch.randn(*normalized_shape, device="cuda",
dtype=torch.float32 if fake_bf16_io else precision)
self.bias = torch.zeros(*normalized_shape, device="cuda",
dtype=torch.float32 if fake_bf16_io else precision)
self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
@@ -614,18 +661,22 @@ def test_export_layernorm(
self.meta,
self.fp8_tensor,
self.fp8_type,
tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16)
as_te_type(precision))
if fake_bf16_io:
ret = ret.type(torch.float32)
return ret
inp = torch.randn(*inp_shape, device="cuda", dtype=precision)
inp = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision)
model = TestFP8_Layernorm() if use_fp8 else Test_Layernorm()
high_prec_str = dtype2str(precision)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fp8_str = f"_fp8-{scale_factor}" if use_fp8 else ""
fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx"
do_export(model, inp, fname, use_fp8=use_fp8)
if precision not in (torch.bfloat16, ):
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16:
validate_result(
fname, inp, model, atol=1e-7, is_fp8=use_fp8, allow_cnt_errors=3)
fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs)
@skip_FP8
@@ -636,15 +687,20 @@ def test_export_layernorm(
te.softmax.FusedScaleMaskSoftmax,
])
# Softmax kernel only supports FP16 or BF16!
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"])
def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision
class Test_Softmax(nn.Module):
def __init__(self, softmax_fn, mask_inp=False):
def __init__(self, softmax_fn, fake_bf16_io, mask_inp=False):
super().__init__()
self.softmax_fn = softmax_fn
self.scale = 8 # arbitrary value
self.mask_inp = mask_inp
self.fused_scaled_softmax = None
self.fake_bf16_io = fake_bf16_io
if self.softmax_fn == te.softmax.FusedScaleMaskSoftmax:
self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
attn_mask_type="causal",
@@ -660,6 +716,8 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision
ret = self.softmax_fn.apply(inp, mask, self.scale)
else:
ret = self.softmax_fn.apply(inp, self.scale)
if self.fake_bf16_io:
ret = ret.type(torch.float16)
return ret
# Set dimensions (these are arbitrary).
@@ -671,38 +729,46 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision
if softmax_fn == softmax_defs.ScaledUpperTriangMaskedSoftmax:
inp_shape = [hidden_size, in_features, in_features]
kernel_str = "ScaledUpperTriangMaskedSoftmax"
model = Test_Softmax(softmax_fn)
model = Test_Softmax(softmax_fn, fake_bf16_io)
elif softmax_fn == softmax_defs.ScaledMaskedSoftmax:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(hidden_size, 1, in_features, in_features, device="cuda", dtype=precision)
mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
kernel_str = "ScaledMaskedSoftmax"
model = Test_Softmax(softmax_fn, mask_inp=True)
model = Test_Softmax(softmax_fn, fake_bf16_io, mask_inp=True)
elif softmax_fn == softmax_defs.ScaledSoftmax:
kernel_str = "ScaledSoftmax"
model = Test_Softmax(softmax_fn)
model = Test_Softmax(softmax_fn, fake_bf16_io)
elif softmax_fn == te.softmax.FusedScaleMaskSoftmax:
kernel_str = "TorchSoftmax"
model = Test_Softmax(softmax_fn)
model = Test_Softmax(softmax_fn, fake_bf16_io)
input_tensor = torch.randn(*inp_shape, device="cuda")
input_tensor = input_tensor.to(torch.bfloat16) if precision == torch.bfloat16 else input_tensor.half()
high_prec_str = dtype2str(precision)
# WAR for BF16 test as ORT doesn't support BF16 IO: FP16 input for both BF16 and FP16 precision types
input_tensor = input_tensor.half()
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"{kernel_str}{high_prec_str}.onnx"
inp = (input_tensor, mask)
do_export(model, inp, fname, input_names=input_names)
if precision != torch.bfloat16:
validate_result(fname, inp, model, atol=1e-3, input_names=input_names)
te_outputs = te_infer(model, inp, is_fp8=False)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if fake_bf16_io or precision != torch.bfloat16:
validate_result(fname, inp, model, atol=1e-3, input_names=input_names, te_outputs=te_outputs)
# Test dynamically generated softmax mask.
# Softmax kernel only supports FP16 or BF16!
@skip_FP8
@pytest.mark.parametrize("precision", [torch.float16])
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"])
def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision
class Test_Softmax(nn.Module):
def __init__(self, use_onnx_mask_fn: bool):
def __init__(self, use_onnx_mask_fn: bool, fake_bf16_io: bool):
super().__init__()
self.scale = 1 # arbitrary value
self.fake_bf16_io = fake_bf16_io
# Use NVTE_MASKED_SOFTMAX_FUSION to force TE to use forward_torch_softmax
# even when is_in_onnx_export_mode()==False.
os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0"
@@ -714,6 +780,8 @@ def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
def forward(self, inp, mask):
ret = self.fused_scaled_softmax(inp, mask, self.scale)
if self.fake_bf16_io:
ret = ret.type(torch.float16)
return ret
# Set dimensions (these are arbitrary).
@@ -722,17 +790,18 @@ def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
mask = None
inp_shape = [hidden_size, in_features, in_features, in_features]
input_tensor = torch.randn(*inp_shape, device="cuda")
input_tensor = input_tensor.to(torch.bfloat16) if precision == torch.bfloat16 else input_tensor.half()
# WAR for BF16 test as ORT doesn't support BF16 IO: FP16 input for both BF16 and FP16 precision types
input_tensor = input_tensor.half()
inp = (input_tensor, mask)
high_prec_str = dtype2str(precision)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
# Compare the outputs of TE when using the default softmax mask
# to the TE outputs produced when using the ONNX-compatible causal mask.
model = Test_Softmax(use_onnx_mask_fn=False)
model = Test_Softmax(use_onnx_mask_fn=False, fake_bf16_io=fake_bf16_io)
te_outputs_default_mask = te_infer(model, inp, is_fp8=True)
with te.onnx_export(True):
# ONNX export mode forces use of the ONNX-compatible causal mask.
model_onnx_mask = Test_Softmax(use_onnx_mask_fn=True)
model_onnx_mask = Test_Softmax(use_onnx_mask_fn=True, fake_bf16_io=fake_bf16_io)
te_outputs_onnx_mask = te_infer(model_onnx_mask, inp, is_fp8=True)
compare_outputs(te_outputs_default_mask, te_outputs_onnx_mask,
atol=0, rtol=0, max_errors_printed=10, allow_cnt_errors=0, fname="softmax masking")
@@ -743,7 +812,8 @@ def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
kernel_str = "FusedScaleMaskSoftmax"
fname = f"{kernel_str}{high_prec_str}.onnx"
do_export(model, inp, fname, input_names=input_names)
if precision != torch.bfloat16:
serialize_inputs_outputs(fname, inp, te_outputs=te_outputs_default_mask, input_names=input_names)
if fake_bf16_io or precision != torch.bfloat16:
validate_result(fname, inp, model_onnx_mask, atol=1e-3, input_names=input_names, te_outputs=te_outputs_default_mask)
@@ -760,7 +830,7 @@ def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
# Todo: cannot configure BF16 when bias is disabled (ORT issue?)
(torch.bfloat16, False),
# Todo: cannot configure BF16 when bias is enabled (ORT issue?)
# (torch.bfloat16, True),
(torch.bfloat16, True),
])
def test_export_linear(
seed_default_rng,
@@ -816,13 +886,15 @@ def test_export_linear(
if use_fp8:
set_layer_scale(model.linear, scale_factor, num_gemms=1)
do_export(model, inp, fname, use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16, ):
return
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3)
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
else:
validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8)
validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8, te_outputs=te_outputs)
@pytest.mark.parametrize("scale_factor", [112])
@@ -836,6 +908,8 @@ def test_export_linear(
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm_linear(
@@ -876,10 +950,14 @@ def test_export_layernorm_linear(
if use_fp8:
set_layer_scale(model, scale_factor, num_gemms=1)
do_export(model, inp, fname, use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16, ):
return
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3)
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
elif precision != torch.bfloat16:
validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8)
validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8, te_outputs=te_outputs)
@pytest.mark.parametrize("scale_factor", [112])
@@ -893,6 +971,8 @@ def test_export_layernorm_linear(
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm_mlp(
@@ -933,10 +1013,14 @@ def test_export_layernorm_mlp(
if use_fp8:
set_layer_scale(model, scale_factor, num_gemms=2)
do_export(model, inp, fname, use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs)
if precision in (torch.bfloat16, ):
return
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3)
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
else:
validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8)
validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8, te_outputs=te_outputs)
@skip_FP8
@pytest.mark.parametrize(
@@ -946,6 +1030,9 @@ def test_export_layernorm_mlp(
(torch.float16, False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
(torch.float16, True, "padding"), # calls ScaledMaskedSoftmax
(torch.float16, False, "padding"), # calls ScaledSoftmax
(torch.bfloat16, False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
(torch.bfloat16, True, "padding"), # calls ScaledMaskedSoftmax
(torch.bfloat16, False, "padding"), # calls ScaledSoftmax
])
def test_export_core_attention(
seed_default_rng,
@@ -988,7 +1075,11 @@ def test_export_core_attention(
fname,
input_names=input_names,
use_fp8=True)
validate_result(fname, inp, model, atol=1e-2, input_names=input_names)
te_outputs = te_infer(model, inp, is_fp8=True)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision in (torch.bfloat16, ):
return
validate_result(fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs)
test_configs_multihead_attention = [
@@ -1010,7 +1101,7 @@ test_configs_attention_type = [
]
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type)
def test_export_multihead_attention(
@@ -1083,11 +1174,18 @@ def test_export_multihead_attention(
do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names,
dynamic_axes={"hidden_states": {0: "seq", 1:"bs"},
"attention_output": {0: "seq", 1:"bs"}})
te_outputs = te_infer(model, inp_context, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp_context, te_outputs, input_names=input_names, output_names=output_names)
if precision in (torch.bfloat16, ):
return
if not use_fp8:
validate_result(fname, inp_context, model, atol=1e-3, input_names=input_names, output_names=output_names)
validate_result(fname, inp_context, model, atol=1e-3, input_names=input_names,
output_names=output_names, te_outputs=te_outputs)
else:
validate_result(fname, inp_context, model, atol=1e-2, is_fp8=use_fp8,
input_names=input_names, output_names=output_names, allow_cnt_errors=3)
input_names=input_names, output_names=output_names, allow_cnt_errors=3,
te_outputs=te_outputs)
# In GPT generative phase (inference) the input sequence is smaller than the maximum
# allowed sequence length and we want to test this condition.
@@ -1111,7 +1209,7 @@ def test_export_multihead_attention(
#True, # TO DO: handle this
False
])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_transformer_layer(
@@ -1161,10 +1259,16 @@ def test_export_transformer_layer(
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma).to(device='cuda')
do_export(model, inp, fname, use_fp8, input_names=input_names)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision in (torch.bfloat16, ):
return
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3, input_names=input_names)
validate_result(fname, inp, model, atol=1e-3, input_names=input_names,
te_outputs=te_outputs)
else:
validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8, input_names=input_names)
validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8, input_names=input_names,
te_outputs=te_outputs)
@pytest.mark.parametrize("use_fp8", [True])
@@ -1285,14 +1389,17 @@ def test_export_gemm_layernorm(
fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx"
input_names = ['input', 'weight']
do_export(model, (inp, weight), fname, use_fp8=use_fp8, input_names=input_names)
te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
if precision not in (torch.bfloat16, ):
validate_result(
fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2, input_names=input_names)
fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2,
input_names=input_names, te_outputs=te_outputs)
@skip_FP8
@pytest.mark.parametrize("use_fp8", [True, False])
@pytest.mark.parametrize("precision", [torch.float16])
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [True])
def test_export_gpt_generation(
seed_default_rng,
@@ -1346,13 +1453,21 @@ def test_export_gpt_generation(
input_names=input_names, output_names=output_names,
dynamic_axes={"input": {0: "seq", 1:"bs"},
"output": {0: "seq", 1:"bs"}, })
validate_result(fname, inp, model, atol=5e-3, is_fp8=use_fp8, input_names=input_names)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names, output_names=output_names)
if precision not in (torch.bfloat16, ):
validate_result(fname, inp, model, atol=6e-3, is_fp8=use_fp8, input_names=input_names,
te_outputs=te_outputs)
# "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8.
sequence_length = 1 if not use_fp8 else 8
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
inp = (input_tensor, attention_mask)
validate_result(fname, inp, model, atol=5e-3, is_fp8=use_fp8, input_names=input_names)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if precision not in (torch.bfloat16, ):
validate_result(fname, inp, model, atol=6e-3, is_fp8=use_fp8, input_names=input_names,
te_outputs=te_outputs)
@pytest.mark.parametrize("enabled", [True, False])