Unverified Commit 4ae9c1a0 authored by asfiyab-nvidia, committed by GitHub

Cast BF16 input/output types for FP8 Q/DQ ONNX ops (#165)

Add casts for BF16 input/output types for FP8 Q/DQ ONNX ops.
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 68fc78dd
@@ -51,7 +51,7 @@ def quantize(g, inputs, scale_inv, fp8_tensor):
     # Q inputs are currently constrained to FP32 due to a similar limitation in ORT
     # custom ops, so cast the input if needed.
-    if inputs.type().scalarType() == "Half":
+    if inputs.type().scalarType() == "Half" or inputs.type().scalarType() == "BFloat16":
         inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
     scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
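Read on its own, the input-side change amounts to the standalone sketch below. This is a minimal illustration, not the file's full code: the custom op name "my_domain::FP8QuantizeLinear" is a hypothetical placeholder, and only the Cast handling mirrors the hunk.

import torch
import torch._C._onnx as _C_onnx

def quantize_sketch(g, inputs, scale_inv, fp8_tensor):
    # ORT custom ops currently accept only FP32 Q inputs, so cast
    # Half and (with this commit) BFloat16 inputs up to FP32 first.
    if inputs.type().scalarType() in ("Half", "BFloat16"):
        inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
    # Hypothetical op name standing in for the repo's real FP8 quantize custom op.
    return g.op("my_domain::FP8QuantizeLinear", inputs, scale)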
@@ -73,6 +73,8 @@ def dequantize(g, inputs, scale_inv, fp8_tensor, otype):
     # custom ops, so cast the output if needed.
     if otype == int(tex.DType.kFloat16):
         out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
+    elif otype == int(tex.DType.kBFloat16):
+        out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
     return out
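On the dequantize side, the commit extends the output cast to BF16: the custom DQ op returns FP32, and the requested otype is mapped back to an ONNX Cast target. A table-driven sketch of that mapping, assuming tex is the transformer_engine extensions module referenced in the diff (the import name here is an assumption):

import torch._C._onnx as _C_onnx
import transformer_engine_extensions as tex  # assumed import name for tex

# Map requested output dtype to an ONNX Cast target; the kBFloat16
# entry is the one added by this commit.
_OTYPE_TO_ONNX = {
    int(tex.DType.kFloat16): _C_onnx.TensorProtoDataType.FLOAT16,
    int(tex.DType.kBFloat16): _C_onnx.TensorProtoDataType.BFLOAT16,
}

def cast_dq_output(g, out, otype):
    # DQ outputs come back from the custom op as FP32; cast to the
    # requested half-precision dtype when one was asked for.
    target = _OTYPE_TO_ONNX.get(otype)
    if target is not None:
        out = g.op("Cast", out, to_i=target)
    return out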