"docs/api/c/swizzle.rst" did not exist on "c9ea6be92948e1ec553037f1a04900617b9f7f6b"
Commit 3fbded65 authored by Neta Zmora, committed by GitHub

Fix softmax ONNX export (#282)



* Fix softmax ONNX export

* BF16 is validated using "fake I/O": i.e., instead of using BF16 as input/output, use FP32 input/output and convert to/from BF16 inside the forward method.
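
  A minimal sketch of this fake-I/O pattern (an illustration only: the module name is made up and plain torch.softmax stands in for the TE softmax functions):

    import torch
    import torch.nn as nn

    class FakeBF16Softmax(nn.Module):  # hypothetical module, for illustration
        def forward(self, inp: torch.Tensor) -> torch.Tensor:
            # The ONNX-visible input stays FP32; the compute runs in BF16.
            ret = torch.softmax(inp.type(torch.bfloat16), dim=-1)
            # Cast back so the ONNX-visible output is FP32 as well.
            return ret.type(torch.float32)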

* Wrap softmax symbolic functions with conversion to/from FP32 to produce the same semantics as TE's softmax (compute is performed at FP32 precision regardless of input/output data type).
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
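
  In the softmax.py changes below, this wrapping is done once with a small decorator that routes every softmax symbolic function through the shared compute_in_fp32 helper (cast a non-FP32 input to FP32, build the subgraph, cast the result back). A condensed restatement of the pattern from the diff:

    def fp32_compute(onnx_symbolic_fn):
        """Decorator that wraps an ONNX symbolic function with FP32 compute operators."""
        def wrapper(g, inp, scale, *args, **kwargs):
            # compute_in_fp32 (defined in te_onnx_extensions.py) inserts the
            # Cast-to-FP32 / Cast-back nodes around the wrapped symbolic function.
            return compute_in_fp32(g, inp, onnx_symbolic_fn, scale, *args, **kwargs)
        return wrapper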

* ONNX export code refactoring

Share function compute_in_fp32 between softmax.py (softmax symbolic functions) and te_onnx_extensions.py (the rest of the symbolic functions).
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
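
  After the refactor, compute_in_fp32 accepts the graph, the input value, the symbolic function itself, and that function's remaining arguments, so call sites pass the target function directly instead of building a lambda and passing cast_outp=True. Call sites taken from the te_onnx_extensions.py diff below:

    # New style: pass the symbolic function and its extra arguments directly.
    gelu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.gelu, "tanh")
    relu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.relu)

    # Old style that was removed:
    #   wrapped_gelu = lambda inps: torch.onnx.symbolic_opset9.gelu(g, inps, "tanh")
    #   gelu = compute_in_fp32(g, inputs, wrapped_gelu, cast_outp=True)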

* Fix exports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent fee8970f
@@ -690,10 +690,6 @@ def test_export_layernorm(
 # Softmax kernel only supports FP16 or BF16!
 @pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"])
 def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision):
-    fake_bf16_io = precision == "fake-torch.bfloat16"
-    # reset precision to torch.bfloat16 after capturing fake BF16 mode
-    precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision
     class Test_Softmax(nn.Module):
         def __init__(self, softmax_fn, fake_bf16_io, mask_inp=False):
             super().__init__()
@@ -710,6 +706,9 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision
             )

         def forward(self, inp, mask):
+            if self.fake_bf16_io:
+                inp = inp.type(torch.bfloat16)
             if self.fused_scaled_softmax:
                 ret = self.fused_scaled_softmax(inp, mask, self.scale)
             else:
@@ -718,12 +717,14 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision
                 else:
                     ret = self.softmax_fn.apply(inp, self.scale)
             if self.fake_bf16_io:
-                ret = ret.type(torch.float16)
+                ret = ret.type(torch.float32)
             return ret

+    fake_bf16_io = precision == "fake-torch.bfloat16"
+    precision = torch.bfloat16 if fake_bf16_io else precision
     # Set dimensions (these are arbitrary).
-    in_features = 64
-    hidden_size = 256
+    in_features, hidden_size = 64, 256
     mask = None
     input_names = ["input", "mask"]
     inp_shape = [hidden_size, in_features, in_features, in_features]
@@ -743,17 +744,18 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision
     elif softmax_fn == te.softmax.FusedScaleMaskSoftmax:
         kernel_str = "TorchSoftmax"
         model = Test_Softmax(softmax_fn, fake_bf16_io)
-    input_tensor = torch.randn(*inp_shape, device="cuda")
-    # WAR for BF16 test as ORT doesn't support BF16 IO: FP16 input for both BF16 and FP16 precision types
-    input_tensor = input_tensor.half()
+    input_tensor = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision)
     high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
     fname = f"{kernel_str}{high_prec_str}.onnx"
     inp = (input_tensor, mask)
     do_export(model, inp, fname, input_names=input_names)
     te_outputs = te_infer(model, inp, is_fp8=False)
     serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
     if fake_bf16_io or precision != torch.bfloat16:
-        validate_result(fname, inp, model, atol=1e-3, input_names=input_names, te_outputs=te_outputs)
+        atol = 5e-2 if fake_bf16_io else 1e-3
+        validate_result(fname, inp, model, atol=atol, input_names=input_names, te_outputs=te_outputs)

 # Test dynamically generated softmax mask.
@@ -763,13 +765,13 @@ def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
 def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
     fake_bf16_io = precision == "fake-torch.bfloat16"
     # reset precision to torch.bfloat16 after capturing fake BF16 mode
-    precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision
+    precision = torch.bfloat16 if fake_bf16_io else precision
     class Test_Softmax(nn.Module):
         def __init__(self, use_onnx_mask_fn: bool, fake_bf16_io: bool):
             super().__init__()
             self.scale = 1  # arbitrary value
             self.fake_bf16_io = fake_bf16_io
             # Use NVTE_MASKED_SOFTMAX_FUSION to force TE to use forward_torch_softmax
             # even when is_in_onnx_export_mode()==False.
             os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0"
@@ -780,9 +782,11 @@ def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
             )

         def forward(self, inp, mask):
+            if self.fake_bf16_io:
+                inp = inp.type(torch.bfloat16)
             ret = self.fused_scaled_softmax(inp, mask, self.scale)
             if self.fake_bf16_io:
-                ret = ret.type(torch.float16)
+                ret = ret.type(torch.float)
             return ret

     # Set dimensions (these are arbitrary).
@@ -790,9 +794,8 @@ def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
     hidden_size = 256
     mask = None
     inp_shape = [hidden_size, in_features, in_features, in_features]
-    input_tensor = torch.randn(*inp_shape, device="cuda")
-    # WAR for BF16 test as ORT doesn't support BF16 IO: FP16 input for both BF16 and FP16 precision types
-    input_tensor = input_tensor.half()
+    input_tensor = torch.randn(
+        *inp_shape, device="cuda", dtype=torch.float if fake_bf16_io else precision)
     inp = (input_tensor, mask)
     high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
@@ -815,7 +818,10 @@ def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
     do_export(model, inp, fname, input_names=input_names)
     serialize_inputs_outputs(fname, inp, te_outputs=te_outputs_default_mask, input_names=input_names)
     if fake_bf16_io or precision != torch.bfloat16:
-        validate_result(fname, inp, model_onnx_mask, atol=1e-3, input_names=input_names, te_outputs=te_outputs_default_mask)
+        atol = 1e-2 if fake_bf16_io else 1e-3
+        validate_result(
+            fname, inp, model_onnx_mask, atol=atol,
+            input_names=input_names, te_outputs=te_outputs_default_mask)

 @pytest.mark.parametrize("scale_factor", [1])
...
@@ -11,6 +11,8 @@ import torch._C._onnx as _C_onnx
 from torch.onnx import _type_utils
 import transformer_engine_extensions as tex
 from transformer_engine.pytorch.export import is_in_onnx_export_mode
+from transformer_engine.pytorch.te_onnx_extensions import compute_in_fp32

 THREADS_PER_WARP = 32
 THREADS_PER_BLOCK = 128
@@ -45,6 +47,13 @@ def _get_onnx_export_causal_mask(
     return derived_mask

+def fp32_compute(onnx_symbolic_fn):
+    """A decorator that wraps an ONNX symoblic function with FP32 compute operators."""
+    def wrapper(g: torch.Graph, inp: torch._C.Value, scale: float, *args, **kwargs):
+        return compute_in_fp32(g, inp, onnx_symbolic_fn, scale, *args, **kwargs)
+    return wrapper
+
 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
@@ -77,6 +86,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         return input_grads, None

     @staticmethod
+    @fp32_compute
     def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
         """ScaledUpperTriangMaskedSoftmax symbolic method"""
         def triangular_mask():
@@ -88,8 +98,6 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
             return mask

         # Captures the logic of function scaled_upper_triang_masked_softmax_warp_forward
-        if inputs.type().scalarType() == "BFloat16":
-            inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
         mask = triangular_mask()
         one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
         inv_mask = g.op("Sub", one, mask)
@@ -102,8 +110,6 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         masked_scaled = g.op("Mul", inv_mask, scaled)
         masked = g.op("Add", masked_scaled, softmax_mask)
         out = g.op("Softmax", masked)
-        if inputs.type().scalarType() == "BFloat16":
-            out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
         return out
@@ -139,6 +145,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         return input_grads, None, None

     @staticmethod
+    @fp32_compute
     def symbolic(
         g: torch.Graph,
         inputs: torch._C.Value,
@@ -151,8 +158,6 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         # masked_scaled = (1 - mask)*(input*scale)
         # softmax_mask = mask * -10000
         # output = softmax(masked_scaled + softmax_mask)
-        if inputs.type().scalarType() == "BFloat16":
-            inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
         scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
         scaled = g.op("Mul", inputs, scale_input)
         one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
@@ -163,8 +168,6 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         masked_scaled = g.op("Mul", inv_mask, scaled)
         masked = g.op("Add", masked_scaled, softmax_mask)
         out = g.op("Softmax", masked)
-        if inputs.type().scalarType() == "BFloat16":
-            out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
         return out
@@ -197,15 +200,12 @@ class ScaledSoftmax(torch.autograd.Function):
         return input_grads, None, None

     @staticmethod
+    @fp32_compute
     def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
         """ScaledSoftmax symbolic method"""
-        if inputs.type().scalarType() == "BFloat16":
-            inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
         scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
         scaled = g.op("Mul", inputs, scale_input)
         out = g.op("Softmax", scaled)
-        if inputs.type().scalarType() == "BFloat16":
-            out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
         return out
...
@@ -107,18 +107,18 @@ def dequantize(g, inputs, scale_inv, fp8_tensor, otype):
     return out

-def compute_in_fp32(g, inp, subgraph, cast_outp):
+def compute_in_fp32(g, inp, subgraph, *args, **kwargs):
     """Wrap subgraph with casts to/from FP32 so that its precision is FP32.

-    If `inp` data type is not FP32, add a cast of `inp` to FP32 and feed that into `subgraph`.
-    Then, if `cast_output` is true, cast subgraphs's output back to `inp` data type.
+    If `inp` data type is not FP32, add a cast of `inp` to FP32 and feed that into `subgraph`;
+    then cast subgraphs's output back to `inp` data type.
     """
     inp_dtype = get_TensorProtoDataType(inp)
     is_fp32 = inp_dtype == _type_utils.JitScalarType.FLOAT
     if not is_fp32:
         inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
-    sg_out = subgraph(inp)
-    if not is_fp32 and cast_outp:
+    sg_out = subgraph(g, inp, *args, **kwargs)
+    if not is_fp32:
         sg_out = g.op("Cast", sg_out, to_i=inp_dtype)
     return sg_out
@@ -141,10 +141,9 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
 def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
     """ONNX graph for fp8_gelu"""
     # pylint: disable=unused-argument
-    wrapped_gelu = lambda inps: torch.onnx.symbolic_opset9.gelu(g, inps, "tanh")
     # TE computes GELU using float32 precision so wrap the GELU subgraph with
     # conversion to/from float32.
-    gelu = compute_in_fp32(g, inputs, wrapped_gelu, cast_outp=True)
+    gelu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.gelu, "tanh")
     if scale_inv:
         gelu = quantize(g, gelu, scale_inv, fp8_tensor)
     return gelu
@@ -154,8 +153,7 @@ def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
 def onnx_fp8_relu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
     """ONNX graph for fp8_relu"""
     # pylint: disable=unused-argument
-    wrapped_relu = lambda inps: torch.onnx.symbolic_opset9.relu(g, inps)
-    relu = compute_in_fp32(g, inputs, wrapped_relu, cast_outp=True)
+    relu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.relu)
     if scale_inv:
         relu = quantize(g, relu, scale_inv, fp8_tensor)
     return relu
@@ -176,8 +174,7 @@ def onnx_swiglu(g: jit_utils.GraphContext, inp, dim):
 def onnx_fp8_swiglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
     """ONNX graph for fp8_swiglu"""
     # pylint: disable=unused-argument
-    wrapped_swiglu = lambda inps: onnx_swiglu(g, inps, 1)
-    swiglu = compute_in_fp32(g, inputs, wrapped_swiglu, cast_outp=True)
+    swiglu = compute_in_fp32(g, inputs, onnx_swiglu, 1)
     if scale_inv:
         swiglu = quantize(g, swiglu, scale_inv, fp8_tensor)
     return swiglu
@@ -198,8 +195,7 @@ def onnx_reglu(g: jit_utils.GraphContext, inp, dim):
 def onnx_fp8_reglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
     """ONNX graph for fp8_reglu"""
     # pylint: disable=unused-argument
-    wrapped_reglu = lambda inps: onnx_reglu(g, inps, 1)
-    reglu = compute_in_fp32(g, inputs, wrapped_reglu, cast_outp=True)
+    reglu = compute_in_fp32(g, inputs, onnx_reglu, 1)
     if scale_inv:
         reglu = quantize(g, reglu, scale_inv, fp8_tensor)
     return reglu
@@ -221,8 +217,7 @@ def onnx_geglu(g: jit_utils.GraphContext, inp, dim):
 def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
     """ONNX graph for fp8_geglu"""
     # pylint: disable=unused-argument
-    wrapped_geglu = lambda inps: onnx_geglu(g, inps, 1)
-    geglu = compute_in_fp32(g, inputs, wrapped_geglu, cast_outp=True)
+    geglu = compute_in_fp32(g, inputs, onnx_geglu, 1)
     if scale_inv:
         geglu = quantize(g, geglu, scale_inv, fp8_tensor)
     return geglu
@@ -276,10 +271,9 @@ def onnx_te_gemm(
     output = g.op("Gemm", inputs, weight, transA_i=trans_input, transB_i=trans_weight)

     if not bias_empty:
         if not pre_gelu_out_empty:
-            wrapped_gelu = lambda inps: torch.onnx.symbolic_opset9.gelu(g, inps, "tanh")
             # TE computes GELU using float32 precision so wrap the GELU subgraph with
             # conversion to/from float32.
-            output = compute_in_fp32(g, output, wrapped_gelu, cast_outp=True)
+            output = compute_in_fp32(g, output, torch.onnx.symbolic_opset9.gelu, "tanh")
         else:
             if is_fp16:
                 output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16)
...