Unverified Commit 6bccc76e authored by galagam, committed by GitHub

Test ONNX export - missing BF16 GEMM tests + output.json fix (#297)


Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c56646e4
@@ -211,7 +211,7 @@ def serialize_inputs_outputs(
     json_fname = fname[:-len(".onnx")] + "_output.json"
     named_outputs = zip(output_names, te_outputs)
-    output_data = {k: v.cpu() for k, v in named_outputs if v is not None}
+    output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None}
     custom_outputs = RunResults()
     custom_outputs.add([output_data], runner_name="custom_runner")
     custom_outputs.save(json_fname)
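
The `_output.json` fix above swaps `v.cpu()` for `v.detach().cpu()`. A minimal sketch of why that matters, assuming the captured outputs still carry autograd history (standalone PyTorch, not the PR's test harness):

```python
import torch

# Illustration only: a tensor that is still attached to the autograd graph
# cannot be converted to NumPy, which the JSON serialization path needs.
y = torch.randn(2, 2, requires_grad=True) * 2.0

try:
    y.cpu().numpy()  # RuntimeError: Can't call numpy() on Tensor that requires grad
except RuntimeError as err:
    print("without detach:", err)

arr = y.detach().cpu().numpy()  # detach() drops the graph reference, so this succeeds
print("with detach:", arr.shape)
```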
@@ -441,14 +441,18 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
     "precision, use_fp8, use_bias, use_gelu", [
     (torch.float32, False, False, False),
     (torch.float16, False, False, False),
+    (torch.bfloat16, False, False, False),
     (torch.float32, False, True, False),
     (torch.float16, False, True, False),
+    (torch.bfloat16, False, True, False),
     (torch.float32, False, True, True),
     (torch.float16, False, True, True),
+    (torch.bfloat16, False, True, True),
     # For FP8 GEMM GeLU is not used.
     (torch.float32, True, False, False),
     (torch.float16, True, False, False),
+    (torch.bfloat16, True, False, False),
     # When enabling bias we must use float16 or bfloat16 (because of kernel limitations)
     (torch.float16, True, True, False),
+    (torch.bfloat16, True, True, False),
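
The new parametrizations above extend the GEMM export test with bfloat16 cases. As a rough, self-contained sketch of the pattern (plain PyTorch and pytest, not the actual Transformer Engine export test; all names below are illustrative):

```python
import pytest
import torch

# Hedged sketch of the parametrization pattern used above: each tuple selects a
# precision/feature combination for one GEMM test case.  This exercises a plain
# torch.nn.functional.linear rather than the real ONNX export path.
@pytest.mark.parametrize(
    "precision, use_bias",
    [
        (torch.float32, False),
        (torch.float16, False),
        (torch.bfloat16, False),  # the kind of BF16 case this commit adds
        (torch.bfloat16, True),
    ],
)
def test_gemm_precision(precision, use_bias):
    if precision is torch.float16 and not torch.cuda.is_available():
        pytest.skip("fp16 GEMM is only exercised on GPU here")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(16, 32, dtype=precision, device=device)
    w = torch.randn(8, 32, dtype=precision, device=device)
    b = torch.randn(8, dtype=precision, device=device) if use_bias else None
    out = torch.nn.functional.linear(a, w, b)
    assert out.dtype == precision
    assert out.shape == (16, 8)
```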