"vscode:/vscode.git/clone" did not exist on "8cc2ab22032b7918870e06eec7c9594363ee3a5a"
Unverified Commit 804f1203 authored by asfiyab-nvidia, committed by GitHub
Browse files

Fix BF16 ONNX export for successful ONNX Runtime Verification (#290)


Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
parent 0426feb6
...@@ -169,14 +169,19 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -169,14 +169,19 @@ class UnfusedDotProductAttention(torch.nn.Module):
key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)
# preallocting result tensor: [b * np, sq, sk] # preallocting result tensor: [b * np, sq, sk]
# WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator
is_bf16 = query_layer.dtype == torch.bfloat16
matmul_result = torch.empty( matmul_result = torch.empty(
output_size[0] * output_size[1], output_size[0] * output_size[1],
output_size[2], output_size[2],
output_size[3], output_size[3],
dtype=query_layer.dtype, dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
) )
if is_in_onnx_export_mode() and is_bf16:
matmul_result = matmul_result.bfloat16()
scale = self.norm_factor scale = self.norm_factor
if apply_qk_layer_scaling: if apply_qk_layer_scaling:
scale *= self.layer_number scale *= self.layer_number
......
...@@ -254,6 +254,7 @@ def onnx_te_gemm( ...@@ -254,6 +254,7 @@ def onnx_te_gemm(
"""ONNX graph for te_gemm""" """ONNX graph for te_gemm"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
is_fp16 = is_dtype_fp16(inputs) is_fp16 = is_dtype_fp16(inputs)
is_bf16 = is_dtype_bf16(inputs)
if input_type == int(tex.DType.kFloat8E4M3): if input_type == int(tex.DType.kFloat8E4M3):
inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type) inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type)
...@@ -277,6 +278,8 @@ def onnx_te_gemm( ...@@ -277,6 +278,8 @@ def onnx_te_gemm(
else: else:
if is_fp16: if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16) output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
elif is_bf16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment