Unverified commit 574f1b41, authored by Neta Zmora, committed by GitHub
Browse files

Fix layer_norm ONNX export (#293)



* Fix ONNX export of layer_norm

ONNX has a spec bug: ConstantOfShape supports all dtypes except for BF16.
As a workaround (WAR), we create the constant with dtype FP32 and then cast it to BF16.

Will also issue a PR to the ONNX sig committee to change the spec in opset 20.
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* fix lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4daea528
......@@ -304,6 +304,20 @@ def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax,
def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
"""ONNX graph for layernorm_fwd"""
# pylint: disable=unused-argument
def ones_like(inp, dtype):
    """Emit ONNX nodes producing a tensor of ones shaped like `inp`.

    The result is built from a `Shape` node on `inp` feeding a
    `ConstantOfShape` node whose fill value is 1 in the requested `dtype`.

    The ONNX spec lets ConstantOfShape carry every data type except BF16.
    For `torch.bfloat16` the constant is therefore created as FP32 and a
    `Cast` node to BFLOAT16 is appended.
    """
    target_shape = g.op("Shape", inp)
    needs_bf16_cast = dtype == torch.bfloat16
    # ConstantOfShape cannot hold a BF16 value, so fall back to FP32 here.
    fill_dtype = torch.float32 if needs_bf16_cast else dtype
    fill_value = torch.tensor([1], dtype=fill_dtype)
    ones = g.op("ConstantOfShape", target_shape, value_t=fill_value)
    if not needs_bf16_cast:
        return ones
    # Convert the FP32 placeholder to the BF16 type the caller asked for.
    return g.op("Cast", ones, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
if normalized_shape is None:
ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs)
......@@ -314,8 +328,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
if zero_centered_gamma:
inputs_dtype = inputs.type().dtype()
shape = g.op("Shape", weight)
one = g.op("ConstantOfShape", shape, value_t=torch.tensor([1], dtype=inputs_dtype))
one = ones_like(weight, inputs_dtype)
weight = g.op("Add", weight, one)
axis = -len(normalized_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment