Unverified Commit 804f1203 authored by asfiyab-nvidia, committed by GitHub

Fix BF16 ONNX export for successful ONNX Runtime Verification (#290)


Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
parent 0426feb6
@@ -169,14 +169,19 @@ class UnfusedDotProductAttention(torch.nn.Module):
         key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)
 
         # preallocating result tensor: [b * np, sq, sk]
+        # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
+        is_bf16 = query_layer.dtype == torch.bfloat16
         matmul_result = torch.empty(
             output_size[0] * output_size[1],
             output_size[2],
             output_size[3],
-            dtype=query_layer.dtype,
+            dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
             device=torch.cuda.current_device(),
         )
+        if is_in_onnx_export_mode() and is_bf16:
+            matmul_result = matmul_result.bfloat16()
+
 
         scale = self.norm_factor
         if apply_qk_layer_scaling:
             scale *= self.layer_number
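This hunk works around a gap in the ONNX standard: torch.empty traces to a ConstantOfShape node during export, and ConstantOfShape has no BF16 variant, so the buffer is allocated in FP32 and cast back to BF16 only in export mode. Below is a minimal standalone sketch of the same pattern; the helper name bf16_safe_empty is hypothetical, and torch.onnx.is_in_onnx_export() stands in for Transformer Engine's is_in_onnx_export_mode():

import torch

def bf16_safe_empty(ref: torch.Tensor, *shape: int) -> torch.Tensor:
    """Preallocate a result buffer that traces to a legal ONNX graph.

    torch.empty lowers to ConstantOfShape during ONNX export, and that
    operator has no BF16 variant, so allocate in FP32 and cast the traced
    value back to BF16 afterwards (hypothetical helper, not TE's API).
    """
    is_bf16 = ref.dtype == torch.bfloat16
    exporting = torch.onnx.is_in_onnx_export()
    out = torch.empty(
        *shape,
        dtype=torch.float32 if exporting and is_bf16 else ref.dtype,
        device=ref.device,
    )
    if exporting and is_bf16:
        out = out.bfloat16()  # recorded as an explicit Cast(to=BFLOAT16) node
    return out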
@@ -254,6 +254,7 @@ def onnx_te_gemm(
     """ONNX graph for te_gemm"""
     # pylint: disable=unused-argument
     is_fp16 = is_dtype_fp16(inputs)
+    is_bf16 = is_dtype_bf16(inputs)
     if input_type == int(tex.DType.kFloat8E4M3):
         inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type)
@@ -277,6 +278,8 @@ def onnx_te_gemm(
     else:
         if is_fp16:
             output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+        elif is_bf16:
+            output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
     return output
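These two hunks extend the te_gemm symbolic so that a BF16 input, like an FP16 one, gets an explicit Cast back to its original dtype after the computation. A hedged sketch of the same shape of fix in a self-contained symbolic function follows; toy_symbolic and the Relu stand-in are illustrative only, and the scalarType() check is an assumption about what helpers such as is_dtype_fp16/is_dtype_bf16 do internally:

from torch._C import _onnx as _C_onnx
from torch.onnx import symbolic_helper

@symbolic_helper.parse_args("v")
def toy_symbolic(g, inputs):
    # Dtype checks on the traced Value; "Half"/"BFloat16" are the JIT type names.
    scalar_type = inputs.type().scalarType()
    is_fp16 = scalar_type == "Half"
    is_bf16 = scalar_type == "BFloat16"
    # Stand-in for the real computation (te_gemm in the actual patch).
    output = g.op("Relu", inputs)
    # Restore the caller's dtype with an explicit Cast node.
    if is_fp16:
        output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
    elif is_bf16:
        output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
    return output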