"vscode:/vscode.git/clone" did not exist on "e5781865289e6d3a6e4d56f7e15efea7b7e48963"
Unverified Commit 6bccc76e authored by galagam, committed by GitHub

Test ONNX export - missing BF16 GEMM tests + output.json fix (#297)


Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c56646e4
@@ -211,7 +211,7 @@ def serialize_inputs_outputs(
     json_fname = fname[:-len(".onnx")] + "_output.json"
     named_outputs = zip(output_names, te_outputs)
-    output_data = {k: v.cpu() for k, v in named_outputs if v is not None}
+    output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None}
     custom_outputs = RunResults()
     custom_outputs.add([output_data], runner_name="custom_runner")
     custom_outputs.save(json_fname)
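The one-line fix above matters because outputs produced under autograd may still require grad, and converting such a tensor to NumPy for JSON serialization raises a RuntimeError. A minimal sketch of the failure mode, using illustrative names only (not code from this repository):

import torch

x = torch.randn(4, requires_grad=True)
y = (x * 2).cpu()            # still attached to the autograd graph
# y.numpy() would raise: "Can't call numpy() on Tensor that requires grad"
z = (x * 2).detach().cpu()   # detach first, then move to host
print(z.numpy())             # safe: plain CPU tensor with no grad history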
@@ -441,14 +441,18 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float
     "precision, use_fp8, use_bias, use_gelu", [
         (torch.float32, False, False, False),
         (torch.float16, False, False, False),
+        (torch.bfloat16, False, False, False),
         (torch.float32, False, True, False),
         (torch.float16, False, True, False),
+        (torch.bfloat16, False, True, False),
         (torch.float32, False, True, True),
         (torch.float16, False, True, True),
+        (torch.bfloat16, False, True, True),
         # For FP8 GEMM GeLU is not used.
         (torch.float32, True, False, False),
         (torch.float16, True, False, False),
+        (torch.bfloat16, True, False, False),
         # When enabling bias we must use float16 or bfloat16 (because of kernel limitations)
         (torch.float16, True, True, False),
         (torch.bfloat16, True, True, False),
...
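The second hunk restores BF16 parity with the existing FP32/FP16 GEMM coverage: each added tuple is the bfloat16 counterpart of a case that was already tested. Roughly how pytest expands such tuples into separate test cases; the test name below is hypothetical, and the hardware guard is an assumption, not part of the patch:

import pytest
import torch

@pytest.mark.parametrize(
    "precision, use_fp8, use_bias, use_gelu",
    [
        (torch.bfloat16, False, False, False),  # plain BF16 GEMM
        (torch.bfloat16, False, True, False),   # + bias
        (torch.bfloat16, False, True, True),    # + bias + GeLU
        (torch.bfloat16, True, False, False),   # FP8 GEMM with BF16 I/O
    ],
)
def test_export_gemm(precision, use_fp8, use_bias, use_gelu):
    # Hypothetical guard: BF16 compute needs Ampere-or-newer GPUs.
    if precision == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        pytest.skip("bfloat16 is not supported on this device")
    ...  # build the GEMM module, export to ONNX, compare outputs

Each tuple becomes its own test case, so the four added lines yield four new BF16 tests without touching the test body.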