Unverified Commit 487871e2 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Revert "Fix BF16 ONNX export for successful ONNX Runtime Verification (#271)" (#275)

This reverts commit 914f3841

.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 914f3841
......@@ -173,7 +173,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=torch.float32 if is_in_onnx_export_mode() else query_layer.dtype,
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)
......
......@@ -176,7 +176,6 @@ def onnx_te_gemm(
"""ONNX graph for te_gemm"""
# pylint: disable=unused-argument
is_fp16 = is_dtype_fp16(inputs)
is_bf16 = is_dtype_bf16(inputs)
if input_type == int(tex.DType.kFloat8E4M3):
inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type)
......@@ -201,8 +200,6 @@ def onnx_te_gemm(
else:
if is_fp16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
elif is_bf16:
output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment