"...include/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "a518bd5cf27f723cd1c2249287d43e3a7bb016c1"
Unverified Commit 2a1069f4 authored by Neta Zmora, committed by GitHub

Fix LayerNorm ONNX export (#174)



* Fix LN ONNX export

When exporting LayerNorm, make sure that the weight and bias
inputs have the same type as the LN input.

Also:
 * Add a regression test.
 * Add an environment variable to override the directory of generated test artifacts.

Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* fix envvar
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix linting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 30212170
@@ -50,7 +50,9 @@ if SAVE_TEST_IO:
     from polygraphy.comparator import RunResults

 # The directory where generated ONNX test models are stored.
-TEST_ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), "./gen_onnx_models")
+NVTE_TEST_ARTIFACTS_DIR = os.environ.get('NVTE_TEST_ARTIFACTS_DIR')
+NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(tempfile.gettempdir(), "./gen_onnx_models")

 # The directory where this file is stored.
 TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
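For reference, a minimal usage sketch of the new override (the target path below is a placeholder, not part of the patch): setting NVTE_TEST_ARTIFACTS_DIR before the test module is imported redirects where generated ONNX models are written; otherwise the temp-dir default above is used.

# Minimal sketch (assumed usage; the path is hypothetical).
import os
import tempfile

os.environ["NVTE_TEST_ARTIFACTS_DIR"] = "/tmp/my_onnx_artifacts"  # hypothetical path

# Same resolution logic as the patch above:
artifacts_dir = os.environ.get('NVTE_TEST_ARTIFACTS_DIR')
artifacts_dir = artifacts_dir or os.path.join(tempfile.gettempdir(), "./gen_onnx_models")
os.makedirs(artifacts_dir, exist_ok=True)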
@@ -91,8 +93,8 @@ def do_export(
     )
     model.cuda().eval()
-    os.makedirs(TEST_ARTIFACTS_DIR, exist_ok=True)
-    fname = os.path.join(TEST_ARTIFACTS_DIR, fname)
+    os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True)
+    fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
     inps = inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,)
     with te.onnx_export(True):
         torch.onnx.export(
@@ -234,7 +236,7 @@ def validate_result(
         raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")
     # Run ORT session and TE model.
-    fname = os.path.join(TEST_ARTIFACTS_DIR, fname)
+    fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
     ort_s = create_ort_session(fname, is_fp8)
     input_feed = create_ort_input_dict(ort_s, inps)
     onnx_outputs = ort_s.run(None, input_feed=input_feed)
@@ -1018,6 +1020,128 @@ def test_export_transformer_layer(
     elif precision != torch.float16:
         validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8)
@pytest.mark.parametrize("use_fp8", [True])
@pytest.mark.parametrize("ln_scale_factor", [448*2])
@pytest.mark.parametrize("gemm_scale_factors", [(224, 224,),])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_gemm_layernorm(
use_fp8: bool,
ln_scale_factor: float,
gemm_scale_factors: Tuple[float, float],
precision: torch.dtype,
zero_centered_gamma: bool
):
"""This is a regression test for testing that all LN inputs have the same type.
The test sets up GEMM with FP32 output which feeds into an LN that is configured
with FP16 or BF16 weights and bias.
"""
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
class TestFP8_GemmLayernorm(nn.Module):
def __init__(self) -> None:
super().__init__()
normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
self.meta = create_meta(ln_scale_factor)
self.fp8_type = tex.DType.kFloat8E4M3
self.gemm = TestFP8_GEMM(
precision, use_bias=False, gelu=False, scale_factors=gemm_scale_factors)
def forward(self, inp, weight):
x = self.gemm(inp, weight)
x = texcpp.layernorm_fwd_fp8_inf(
x,
self.weight,
self.bias,
self.eps,
self.meta,
self.fp8_tensor,
self.fp8_type,
zero_centered_gamma)
x = cast_from_fp8(
x,
self.meta,
self.fp8_tensor,
self.fp8_type,
tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16)
return x
out_features = 128
hidden_size = 128
in_features = 128
class TestFP8_GEMM(nn.Module):
def __init__(self, precision, use_bias, gelu, scale_factors):
super().__init__()
self.use_bias = use_bias
self.gelu = gelu
self.precision = precision
self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
nb_inp_scales, nb_weight_scales = 1, out_features
act_scale_factor, weight_scale_factor = scale_factors
self.meta_inp = create_meta(act_scale_factor, nb_inp_scales)
self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales)
bias_size = nb_weight_scales
self.bias = torch.randn(bias_size, dtype=precision, device="cuda")
self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda")
self.inp_type = tex.DType.kFloat8E4M3
self.weights_type = tex.DType.kFloat8E4M3
self.outp_type = precision
def forward(self, inp, weight):
inp_fp8 = cast_to_fp8(
inp,
self.meta_inp,
self.fp8_tensor_inp,
self.inp_type)
weight_fp8 = cast_to_fp8(
weight,
self.meta_weight,
self.fp8_tensor_weight,
self.weights_type)
ret = fp8_gemm(
weight_fp8,
self.meta_weight.scale_inv,
self.fp8_tensor_weight,
self.inp_type,
inp_fp8,
self.meta_inp.scale_inv,
self.fp8_tensor_inp,
self.weights_type,
self.outp_type,
get_workspace(),
bias=self.bias,
use_bias=self.use_bias,
use_split_accumulator=False)
return ret
inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda")
weight = torch.randn(out_features, in_features, dtype=precision, device="cuda")
model = TestFP8_GemmLayernorm()
high_prec_str = dtype2str(precision)
fp8_str = f"_fp8" if use_fp8 else ""
fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx"
do_export(model, (inp, weight), fname, use_fp8=use_fp8)
if precision not in (torch.bfloat16, ):
validate_result(
fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2)
@pytest.mark.parametrize("enabled", [True, False]) @pytest.mark.parametrize("enabled", [True, False])
def test_export_ctx_manager(enabled): def test_export_ctx_manager(enabled):
assert is_in_onnx_export_mode() == False assert is_in_onnx_export_mode() == False
...
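A hedged usage note: to exercise only the new regression test across all of its parametrizations, something like the following pytest invocation should work (run from the directory containing the ONNX export test module; the exact file path is not shown in this diff).

# Minimal sketch (assumption: pytest is installed and the test module is on the collection path).
import pytest

# Select only the new regression test by keyword.
pytest.main(["-q", "-k", "test_export_gemm_layernorm"])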
@@ -45,13 +45,40 @@ def make_op_name(op_name: str) -> str:
     return "trt::" + op_name
+
+
+def get_TensorProtoDataType(t):
+    """Return the _C_onnx.TensorProtoDataType of the input tensor"""
+    try:
+        return {
+            "Float": _C_onnx.TensorProtoDataType.FLOAT,
+            "Half": _C_onnx.TensorProtoDataType.FLOAT16,
+            "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16,
+        }[t.type().scalarType()]
+    except KeyError as e:
+        raise TypeError(f"Onnx export for dtype {t.type().scalarType()} not supported.") from e
+
+
+def is_dtype_fp32(t):
+    """Check fp32 dtype"""
+    return t.type().scalarType() == "Float"
+
+
+def is_dtype_fp16(t):
+    """Check fp16 dtype"""
+    return t.type().scalarType() == "Half"
+
+
+def is_dtype_bf16(t):
+    """Check bf16 dtype"""
+    return t.type().scalarType() == "BFloat16"
+
+
 def quantize(g, inputs, scale_inv, fp8_tensor):
     """Helper Function for Quantization"""
     output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
     # Q inputs are currently constrained to FP32 due to a similar limitation in ORT
     # custom ops, so cast the input if needed.
-    if inputs.type().scalarType() == "Half" or inputs.type().scalarType() == "BFloat16":
+    if not is_dtype_fp32(inputs):
         inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
     scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
@@ -84,15 +111,7 @@ def compute_in_fp32(g, inp, subgraph, cast_outp):
     If `inp` data type is not FP32, add a cast of `inp` to FP32 and feed that into `subgraph`.
     Then, if `cast_output` is true, cast subgraph's output back to `inp` data type.
     """
-    try:
-        inp_dtype = {
-            "Float": _C_onnx.TensorProtoDataType.FLOAT,
-            "Half": _C_onnx.TensorProtoDataType.FLOAT16,
-            "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16,
-        }[inp.type().scalarType()]
-    except KeyError as e:
-        raise TypeError(f"Onnx export for dtype {inp.type().scalarType()} not supported.") from e
+    inp_dtype = get_TensorProtoDataType(inp)
     is_fp32 = inp_dtype == _type_utils.JitScalarType.FLOAT
     if not is_fp32:
         inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
@@ -158,7 +177,7 @@ def onnx_te_gemm(
                  use_split_accumulator):
     """ONNX graph for te_gemm"""
     # pylint: disable=unused-argument
-    is_fp16 = bias.type().scalarType() == "Half"
+    is_fp16 = is_dtype_fp16(inputs)
     if input_type == int(tex.DType.kFloat8E4M3):
         inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, UNSPECIFIED_TYPE)
@@ -189,9 +208,16 @@ def onnx_te_gemm(
 @symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "b")
 def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax,
                            scale_inv, fp8_tensor, otype, zero_centered_gamma):
     """ONNX graph for layernorm_fwd_fp8"""
     # pylint: disable=unused-argument
+    inp_dtype = get_TensorProtoDataType(inputs)
+
+    if inp_dtype != get_TensorProtoDataType(weight):
+        weight = g.op("Cast", weight, to_i=inp_dtype)
+    if inp_dtype != get_TensorProtoDataType(bias):
+        bias = g.op("Cast", bias, to_i=inp_dtype)
+
     ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma)
     fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
     return fp8_ln
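The Cast nodes added above are the substance of the fix: the exported LayerNorm expects its input, weight, and bias to share one element type, whereas a GEMM producing FP32 can feed an LN whose parameters were created in FP16 or BF16. A minimal plain-PyTorch sketch of the same reconciliation (illustration only, not the TE custom op):

# Illustration only: align LayerNorm parameter dtypes with the activation dtype,
# mirroring what the exporter now does with Cast nodes.
import torch
import torch.nn.functional as F

x = torch.randn(4, 128, dtype=torch.float32)       # FP32 activation (e.g. a GEMM output)
gamma = torch.randn(128, dtype=torch.float16)      # FP16 LN weight
beta = torch.zeros(128, dtype=torch.float16)       # FP16 LN bias

# Casting the parameters to x.dtype keeps all three LN inputs consistent,
# which is what the exported graph needs to be valid.
out = F.layer_norm(x, (128,), gamma.to(x.dtype), beta.to(x.dtype), eps=1e-6)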
@@ -210,7 +236,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
         normalized_shape = normalized_shape[1:]

     if zero_centered_gamma:
-        inputs_dtype= inputs.type().dtype()
+        inputs_dtype = inputs.type().dtype()
         one = g.op("Constant", value_t=torch.tensor([1.], dtype=inputs_dtype, device="cuda"))
         weight = g.op("Add", weight, one)
...