"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "493649f965ef8ba613a6c60d3881adeae1052518"
Unverified Commit 914f3841 authored by asfiyab-nvidia's avatar asfiyab-nvidia Committed by GitHub
Browse files

Fix BF16 ONNX export for successful ONNX Runtime Verification (#271)



* fix BF16 onnx export for ort verification
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarasfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com>

---------
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>
Signed-off-by: default avatarasfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6b6823a1
...@@ -173,7 +173,7 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -173,7 +173,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
output_size[0] * output_size[1], output_size[0] * output_size[1],
output_size[2], output_size[2],
output_size[3], output_size[3],
dtype=query_layer.dtype, dtype=torch.float32 if is_in_onnx_export_mode() else query_layer.dtype,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
) )
......
...@@ -176,6 +176,7 @@ def onnx_te_gemm( ...@@ -176,6 +176,7 @@ def onnx_te_gemm(
"""ONNX graph for te_gemm""" """ONNX graph for te_gemm"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
is_fp16 = is_dtype_fp16(inputs) is_fp16 = is_dtype_fp16(inputs)
is_bf16 = is_dtype_bf16(inputs)
if input_type == int(tex.DType.kFloat8E4M3): if input_type == int(tex.DType.kFloat8E4M3):
inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type) inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type)
...@@ -200,6 +201,8 @@ def onnx_te_gemm( ...@@ -200,6 +201,8 @@ def onnx_te_gemm(
else: else:
if is_fp16: if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16) output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
elif is_bf16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment