Unverified commit 914f3841, authored by asfiyab-nvidia, committed by GitHub

Fix BF16 ONNX export for successful ONNX Runtime Verification (#271)



* fix BF16 onnx export for ort verification
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: asfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com>

---------
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Signed-off-by: asfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6b6823a1
transformer_engine/pytorch/attention.py
@@ -173,7 +173,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
             output_size[0] * output_size[1],
             output_size[2],
             output_size[3],
-            dtype=query_layer.dtype,
+            dtype=torch.float32 if is_in_onnx_export_mode() else query_layer.dtype,
             device=torch.cuda.current_device(),
         )
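Context on the change above: the attention buffer's dtype is gated on export mode so the traced graph keeps float32 intermediates, which ONNX Runtime can verify against its own float32 reference; outside export the buffer still follows the query dtype (e.g. bf16). A minimal sketch of the pattern, with a hypothetical module-level flag standing in for Transformer Engine's is_in_onnx_export_mode helper:

import torch

# Stand-in for Transformer Engine's export-mode helper; a hypothetical
# module-level flag is used here for illustration only.
_ONNX_EXPORT_MODE = False

def is_in_onnx_export_mode() -> bool:
    return _ONNX_EXPORT_MODE

def alloc_attention_buffer(query_layer: torch.Tensor, output_size):
    # During export the buffer is forced to float32 so the traced graph
    # avoids BF16 intermediates that ONNX Runtime may reject when
    # verifying the exported model; otherwise it follows the query dtype.
    return torch.empty(
        output_size[0] * output_size[1],
        output_size[2],
        output_size[3],
        dtype=torch.float32 if is_in_onnx_export_mode() else query_layer.dtype,
        device=query_layer.device,
    )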
@@ -176,6 +176,7 @@ def onnx_te_gemm(
     """ONNX graph for te_gemm"""
     # pylint: disable=unused-argument
     is_fp16 = is_dtype_fp16(inputs)
+    is_bf16 = is_dtype_bf16(inputs)
     if input_type == int(tex.DType.kFloat8E4M3):
         inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type)
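The new is_bf16 flag mirrors the existing is_fp16 check so the epilogue below knows whether to cast the result back to BF16. A sketch of what such a check plausibly looks like, assuming it inspects the scalar type of the traced value the same way the FP16 helper does (this signature is an assumption, not the library's confirmed API):

def is_dtype_bf16(t) -> bool:
    # `t` is the torch._C.Value handed to the symbolic function during
    # export; its tensor type reports the scalar dtype as a string.
    try:
        return t.type().scalarType() == "BFloat16"
    except RuntimeError:
        return False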
@@ -200,6 +201,8 @@
     else:
         if is_fp16:
             output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+        elif is_bf16:
+            output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
     return output
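The fix itself: the dequantized GEMM output is emitted in float32, and previously only FP16 inputs were cast back, so BF16 models exported with a float32 output and failed ONNX Runtime verification on the dtype mismatch. The added elif restores the BF16 dtype with an ONNX Cast node. A self-contained sketch of the same cast-back logic as a hypothetical helper (the helper name is illustrative; only g.op("Cast", ...) and the TensorProtoDataType values come from the diff):

import torch._C._onnx as _C_onnx

def cast_output_to_input_dtype(g, output, is_fp16: bool, is_bf16: bool):
    # `g` is the graph context torch.onnx passes to symbolic functions.
    # Restore the caller-visible dtype after a float32 GEMM.
    if is_fp16:
        return g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
    if is_bf16:
        return g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
    return output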