Unverified Commit 2aa3cd93 authored by Funtowicz Morgan's avatar Funtowicz Morgan Committed by GitHub
Browse files

[RFC] Laying down building stone for more flexible ONNX export capabilities (#11786)



* Laying down building stone for more flexible ONNX export capabilities

* Ability to provide a map of config key to override before exporting.

* Makes it possible to export BART with/without past keys.

* Supports simple mathematical syntax for OnnxVariable.repeated

* Effectively apply value override from onnx config for model

* Supports export with additional features such as with-past for seq2seq

* Store the output path directly in the args for uniform usage across.

* Make BART_ONNX_CONFIG_* constants and fix imports.

* Support BERT model.

* Use tokenizer for more flexibility in defining the inputs of a model.

* Add TODO as a reminder to provide the batch/sequence_length as CLI args

* Enable optimizations to be done on the model.

* Enable GPT2 + past

* Improve model validation with outputs containing nested structures

* Enable Roberta

* Enable Albert

* Albert requires opset >= 12

* BERT-like models require opset >= 12

* Remove double printing.

* Enable XLM-Roberta

* Enable DistilBERT

* Disable optimization by default

* Fix missing setattr when applying optimizer_features

* Add value field to OnnxVariable to define constant input (not from tokenizers)

* Add T5 support.

* Simplify model type retrieval

* Example exporting token_classification pipeline for DistilBERT.

* Refactoring to package `transformers.onnx`

* Solve circular dependency & __main__

* Remove unnecessary imports in `__init__`

* Licences

* Use @Narsil's suggestion to forward the model's configuration to the ONNXConfig to avoid interpolation.

* Onnx export v2 fixes (#12388)

* Tiny fixes
Remove `convert_pytorch` from onnxruntime-less runtimes
Correct reference to model

* Style

* Fix Copied from

* LongFormer ONNX config.

* Removed optimizations

* Remove bad merge relics.

* Remove unused constants.

* Remove some deleted constants from imports.

* Fix unittest to remove usage of PyTorch model for onnx.utils.

* Fix distilbert export

* Enable ONNX export test for supported model.

* Style.

* Fix lint.

* Enable all supported default models.

* GPT2 only has one output

* Fix bad property name when overriding config.

* Added unittests and docstrings.

* Disable with_past tests for now.

* Enable outputs validation for default export.

* Remove graph opt lvls.

* Last commit with on-going past commented.

* Style.

* Disabled `with_past` for now

* Remove unused imports.

* Remove framework argument

* Remove TFPreTrainedModel reference

* Add documentation

* Add onnxruntime tests to CircleCI

* Add test

* Rename `convert_pytorch` to `export`

* Use OrderedDict for dummy inputs

* WIP Wav2Vec2

* Revert "WIP Wav2Vec2"

This reverts commit f665efb04c92525c3530e589029f0ae7afdf603e.

* Style

* Use OrderedDict for I/O

* Style.

* Specify OrderedDict documentation.

* Style :)
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 0085e712
...@@ -14,7 +14,10 @@ ...@@ -14,7 +14,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" XLM-RoBERTa configuration """ """ XLM-RoBERTa configuration """
from collections import OrderedDict
from typing import Mapping
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
from ..roberta.configuration_roberta import RobertaConfig from ..roberta.configuration_roberta import RobertaConfig
...@@ -38,3 +41,19 @@ class XLMRobertaConfig(RobertaConfig): ...@@ -38,3 +41,19 @@ class XLMRobertaConfig(RobertaConfig):
""" """
model_type = "xlm-roberta" model_type = "xlm-roberta"
# Copied from transformers.models.roberta.configuration_roberta.RobertaOnnxConfig with Roberta->XLMRoberta
class XLMRobertaOnnxConfig(OnnxConfig):
    """
    ONNX export description for XLM-RoBERTa: names the input and output tensors and labels their axes.
    """

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Ordered mapping from each input tensor name to its `{axis index: axis name}` mapping.

        Both `input_ids` and `attention_mask` label axis 0 as "batch" and axis 1 as "sequence".
        """
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Ordered mapping from each output tensor name to its `{axis index: axis name}` mapping.

        `last_hidden_state` labels axis 0 as "batch" and axis 1 as "sequence"; `pooler_output` only labels axis 0
        as "batch".
        """
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
# flake8: noqa
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast
from .convert import export, validate_model_outputs
from .utils import ParameterFormat, compute_serialized_parameters_size
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
from pathlib import Path
from typing import Callable, Tuple
from transformers.models.albert import AlbertOnnxConfig
from transformers.models.auto import AutoTokenizer
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert import BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from .. import is_torch_available
from ..utils import logging
from .convert import export, validate_model_outputs
# The PyTorch model classes are only imported when torch is installed.
if is_torch_available():
    from transformers import AutoModel, PreTrainedModel

    # Maps each exportable feature name to the auto-class used to load the model for that feature.
    # NOTE(review): assumed to live inside the torch guard because it references AutoModel -- confirm against the
    # original (indentation was lost in this view).
    FEATURES_TO_AUTOMODELS = {
        "default": AutoModel,
    }
# Registry of the supported model topologies. For each `config.model_type`, maps every supported feature name to the
# factory that builds the matching OnnxConfig from the model's configuration (called as `factory(model.config)`).
SUPPORTED_MODEL_KIND = {
    "albert": {"default": AlbertOnnxConfig.default},
    "bart": {"default": BartOnnxConfig.default},
    "bert": {"default": BertOnnxConfig.default},
    "distilbert": {"default": DistilBertOnnxConfig.default},
    "gpt2": {"default": GPT2OnnxConfig.default},
    "longformer": {"default": LongformerOnnxConfig.default},
    # Registered through `.default` like every other entry so all values are factories of the same kind
    "roberta": {"default": RobertaOnnxConfig.default},
    "t5": {"default": T5OnnxConfig.default},
    "xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
}
def get_model_from_features(features: str, model: str):
    """
    Load the model to export, using the auto-class registered for the requested feature set.

    Args:
        features: Name of the feature set to export; must be a key of `FEATURES_TO_AUTOMODELS`
        model: Model identifier forwarded to `from_pretrained`

    Returns:
        The model returned by `from_pretrained` of the auto-class registered for `features`

    Raises:
        KeyError: If `features` is not a key of `FEATURES_TO_AUTOMODELS`
    """
    if features not in FEATURES_TO_AUTOMODELS:
        # Report the accepted feature names (the keys), not the auto-classes they map to
        raise KeyError(f"Unknown feature: {features}. Possible values are {list(FEATURES_TO_AUTOMODELS.keys())}")

    return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
def check_supported_model_or_raise(model: "PreTrainedModel", features: str = "default") -> Tuple[str, Callable]:
    """
    Check whether or not the model has the requested features

    Args:
        model: The model to export
        features: The name of the features to check if they are available

    Returns:
        (str) The type of the model, (Callable) the factory registered in `SUPPORTED_MODEL_KIND` for this model type
        and feature

    Raises:
        KeyError: If the model type is not a key of `SUPPORTED_MODEL_KIND`
        ValueError: If the model type is supported but does not provide the requested features
    """
    model_type = model.config.model_type

    if model_type not in SUPPORTED_MODEL_KIND:
        # Not every model object exposes `name`; fall back so the intended KeyError is raised, not AttributeError
        model_name = getattr(model, "name", None) or getattr(model, "name_or_path", "unknown")
        raise KeyError(
            f"{model_type} ({model_name}) is not supported yet. "
            f"Only {list(SUPPORTED_MODEL_KIND.keys())} are supported. "
            f"If you want to support ({model_type}) please propose a PR or open up an issue."
        )

    # Look for the features
    model_features = SUPPORTED_MODEL_KIND[model_type]
    if features not in model_features:
        raise ValueError(
            f"{model_type} doesn't support features {features}. "
            f"Supported values are: {list(model_features.keys())}"
        )

    return model_type, model_features[features]
def main():
    """
    Command line entry point: load a model and its tokenizer, export the model to ONNX and validate the result.

    The positional `output` argument is used as-is when it points to an existing file; otherwise it is treated as a
    directory and the model is written to `<output>/model.onnx`.

    Raises:
        ValueError: If the requested opset is lower than the one required by the model's ONNX configuration
    """
    # Resolve the logger locally so `main` also works when it is imported and called from other code
    logger = logging.get_logger("transformers.onnx")

    parser = ArgumentParser("Hugging Face ONNX Exporter tool")
    parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.")
    parser.add_argument(
        "--features",
        choices=["default"],
        default="default",
        help="Export the model with some additional features.",
    )
    parser.add_argument(
        "--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
    )
    parser.add_argument(
        "--atol", type=float, default=1e-4, help="Absolute difference tolerence when validating the model."
    )
    parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")

    # Retrieve CLI arguments
    args = parser.parse_args()
    args.output = args.output if args.output.is_file() else args.output.joinpath("model.onnx")

    # `exist_ok` avoids failing if the directory appears between a separate existence check and its creation
    args.output.parent.mkdir(parents=True, exist_ok=True)

    # Allocate the model
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = get_model_from_features(args.features, args.model)
    model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
    onnx_config = model_onnx_config(model.config)

    # Ensure the requested opset is sufficient
    if args.opset < onnx_config.default_onnx_opset:
        raise ValueError(
            f"Opset {args.opset} is not sufficient to export {model_kind}. "
            f"At least {onnx_config.default_onnx_opset} is required."
        )

    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output)

    validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
    logger.info(f"All good, model saved at: {args.output.as_posix()}")
# Script entry point: the statements below only run when this module is executed directly.
if __name__ == "__main__":
    # Bound as a module-level name; note that `logger` only exists when the module is run as a script.
    logger = logging.get_logger("transformers.onnx") # pylint: disable=invalid-name
    # Let INFO-level records through on this logger
    logger.setLevel(logging.INFO)
    main()
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Mapping, Optional
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size
DEFAULT_ONNX_OPSET = 11
# 2 Gb
EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024
class OnnxConfig(ABC):
    """
    Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.

    Subclasses have to implement the `inputs` and `outputs` properties. The other properties provide defaults which
    subclasses may override.
    """

    # Dimensions substituted for dynamic axes (-1) when generating dummy inputs
    DEFAULT_FIXED_BATCH = 2
    DEFAULT_FIXED_SEQUENCE = 8

    def __init__(self, config: PretrainedConfig):
        self._config = config

    @classmethod
    def default(cls, config: PretrainedConfig) -> "OnnxConfig":
        """
        Instantiate a OnnxConfig for a specific model

        Args:
            config: The model's configuration to use when exporting to ONNX

        Returns:
            OnnxConfig for this model
        """
        return cls(config)

    @property
    @abstractmethod
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Mapping containing the axis definition of the input tensors to provide to the model

        Returns:
            For each input: its name associated to the axes symbolic name and the axis position within the tensor
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Mapping containing the axis definition of the output tensors to provide to the model

        Returns:
            For each output: its name associated to the axes symbolic name and the axis position within the tensor
        """
        raise NotImplementedError()

    @property
    def values_override(self) -> Optional[Mapping[str, Any]]:
        """
        Dictionary of keys to override in the model's config before exporting

        Returns:
            Dictionary with the keys (and their corresponding values) to override: `{"use_cache": False}` when the
            wrapped config has a `use_cache` attribute, None otherwise
        """
        if hasattr(self._config, "use_cache"):
            return {"use_cache": False}

        return None

    @property
    def default_batch_size(self) -> int:
        """
        The default batch size to use if no other indication

        Returns:
            Integer > 0
        """
        # Using 2 avoid ONNX making assumption about single sample batch
        return OnnxConfig.DEFAULT_FIXED_BATCH

    @property
    def default_sequence_length(self) -> int:
        """
        The default sequence length to use if no other indication

        Returns:
            Integer > 0
        """
        return OnnxConfig.DEFAULT_FIXED_SEQUENCE

    @property
    def default_onnx_opset(self) -> int:
        """
        Which onnx opset to use when exporting the model

        Returns:
            Integer ONNX Opset version
        """
        return DEFAULT_ONNX_OPSET

    @staticmethod
    def use_external_data_format(num_parameters: int) -> bool:
        """
        Flag indicating if the model requires using external data format

        Args:
            num_parameters: Number of parameter on the model

        Returns:
            True if model.num_parameters() * size_of(float32) >= 2Gb False otherwise
        """
        return (
            compute_serialized_parameters_size(num_parameters, ParameterFormat.Float)
            >= EXTERNAL_DATA_FORMAT_SIZE_LIMIT
        )

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        """
        Generate inputs to provide to the ONNX exporter for the specific framework

        Args:
            tokenizer: The tokenizer associated with this model configuration
            batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
            seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
            is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
            framework: The framework (optional) the tokenizer will generate tensor for

        Returns:
            Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
        """
        # If dynamic axis (-1) we forward with the default batch size to avoid optimizations made by ONNX.
        # The overridable property is used (rather than the class constant) so subclasses can change the default.
        batch_size = compute_effective_axis_dimension(
            batch_size, fixed_dimension=self.default_batch_size, num_token_to_add=0
        )

        # If dynamic axis (-1) we forward with the default sequence length to avoid optimizations made by ONNX.
        # The special tokens the tokenizer will add are subtracted from the requested length.
        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
        seq_length = compute_effective_axis_dimension(
            seq_length, fixed_dimension=self.default_sequence_length, num_token_to_add=token_to_add
        )

        # Generate dummy inputs according to compute batch and sequence.
        # NOTE: the join runs over a single-element list, so each sample is the unknown token repeated `seq_length`
        # times without any separator.
        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
        return dict(tokenizer(dummy_input, return_tensors=framework))
class OnnxConfigWithPast(OnnxConfig, ABC):
    """
    OnnxConfig variant carrying a `use_past` flag, which drives the `use_cache` override and the dummy inputs shape.
    """

    def __init__(self, config: PretrainedConfig, use_past: bool = False):
        super().__init__(config)
        self.use_past = use_past

    @classmethod
    def with_past(cls, config: PretrainedConfig) -> "OnnxConfigWithPast":
        """
        Alternate constructor building the config with `use_past` enabled

        Args:
            config: The underlying model's config to use when exporting to ONNX

        Returns:
            An instance of `cls` whose `.use_past` is True
        """
        return cls(config, use_past=True)

    @property
    def values_override(self) -> Optional[Mapping[str, Any]]:
        """
        Keys to override in the model's config before exporting

        Returns:
            `{"use_cache": self.use_past}` when the wrapped config has a `use_cache` attribute, None otherwise
        """
        if not hasattr(self._config, "use_cache"):
            return None

        return {"use_cache": self.use_past}

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        """
        Build the tokenized dummy batch to feed to the exporter

        Args:
            tokenizer: The tokenizer used to encode the dummy text
            batch_size: Requested batch size, non-positive values fall back to `default_batch_size`
            seq_length: Requested sequence length, non-positive values fall back to 1 when `use_past` is set and to
                `default_sequence_length` otherwise
            is_pair: Forwarded to `tokenizer.num_special_tokens_to_add`
            framework: Forwarded to the tokenizer as `return_tensors`

        Returns:
            OrderedDict built from the tokenizer's output
        """
        # Batch axis: no special token is ever added on this axis
        effective_batch = compute_effective_axis_dimension(
            batch_size, fixed_dimension=self.default_batch_size, num_token_to_add=0
        )

        # Sequence axis: the special tokens added by the tokenizer are deducted from the target length
        special_token_count = tokenizer.num_special_tokens_to_add(is_pair)

        # With past enabled the caching mechanism expects a single input token
        if self.use_past:
            fixed_sequence_length = 1
        else:
            fixed_sequence_length = self.default_sequence_length

        effective_sequence = compute_effective_axis_dimension(
            seq_length, fixed_dimension=fixed_sequence_length, num_token_to_add=special_token_count
        )

        # One sample is the unknown token repeated along the sequence, the batch repeats that sample
        sample = " ".join([tokenizer.unk_token]) * effective_sequence
        batch = [sample] * effective_batch
        encoded = tokenizer(batch, return_tensors=framework)
        return OrderedDict(dict(encoded))
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from inspect import signature
from itertools import chain
from pathlib import Path
from typing import Iterable, List, Tuple, Union
import numpy as np
from packaging.version import Version, parse
from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
from ..utils import logging
from .config import OnnxConfig
from .utils import flatten_output_collection_property
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# This is the minimal required version to support some ONNX Runtime features
ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")
def check_onnxruntime_requirements(minimum_version: "Version"):
    """
    Check onnxruntime is installed and if the installed version is recent enough

    Args:
        minimum_version: The lowest onnxruntime version accepted

    Raises:
        ImportError: If onnxruntime is not installed or too old version is found
    """
    # Only the import itself is guarded: keeping the version check outside of the `try` prevents the
    # "too old" error from being caught and re-reported as "not installed".
    try:
        import onnxruntime
    except ImportError as import_error:
        raise ImportError(
            "onnxruntime doesn't seem to be currently installed. "
            "Please install the onnxruntime by running `pip install onnxruntime`"
            " and relaunch the conversion."
        ) from import_error

    # Parse the version of the installed onnxruntime
    ort_version = parse(onnxruntime.__version__)

    # Compare against the version requested by the caller, which is also the one reported in the message
    if ort_version < minimum_version:
        raise ImportError(
            f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
            f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
            f"Please update onnxruntime by running `pip install --upgrade onnxruntime`"
        )
def export(
    tokenizer: PreTrainedTokenizer, model: PreTrainedModel, config: OnnxConfig, opset: int, output: Path
) -> Tuple[List[str], List[str]]:
    """
    Export a PyTorch model to an ONNX file through `torch.onnx.export`

    Args:
        tokenizer: The tokenizer used to generate the dummy inputs fed to the exporter
        model: The model to export
        config: The ONNX configuration describing the inputs, outputs and config overrides of the model
        opset: The ONNX opset version forwarded to the exporter
        output: Path of the ONNX file to write

    Returns:
        (List[str]) the dummy input names ordered as in the model's forward signature, (List[str]) the output names
        declared by `config`

    Raises:
        Exception: If PyTorch is not available
        ValueError: If some generated inputs are not parameters of the model's forward method
    """
    if not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")

    import torch
    from torch.onnx import export as onnx_export

    logger.info(f"Using framework PyTorch: {torch.__version__}")

    # NOTE: autograd is switched off globally here and is not switched back on afterwards
    torch.set_grad_enabled(False)
    model.config.return_dict = True

    # Apply the configuration overrides requested by the ONNX config, if any
    overrides = config.values_override
    if overrides is not None:
        logger.info(f"Overriding {len(overrides)} configuration item(s)")
        for config_key, config_value in overrides.items():
            logger.info(f"\t- {config_key} -> {config_value}")
            setattr(model.config, config_key, config_value)

    # Ensure inputs match
    # TODO: Check when exporting QA we provide "is_pair=True"
    dummy_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
    inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, dummy_inputs.keys())
    onnx_outputs = list(config.outputs.keys())

    if not inputs_match:
        raise ValueError("Model and config inputs doesn't match")

    # The dummy inputs are handed over as a dict of named arguments, placed last in the args tuple
    target_file = output.as_posix()
    input_names = list(config.inputs.keys())
    axes_by_name = dict(chain(config.inputs.items(), config.outputs.items()))
    needs_external_data = config.use_external_data_format(model.num_parameters())

    onnx_export(
        model,
        (dummy_inputs,),
        f=target_file,
        input_names=input_names,
        output_names=onnx_outputs,
        dynamic_axes=axes_by_name,
        do_constant_folding=True,
        use_external_data_format=needs_external_data,
        enable_onnx_checker=True,
        opset_version=opset,
    )

    return matched_inputs, onnx_outputs
def validate_model_outputs(
    config: OnnxConfig,
    tokenizer: PreTrainedTokenizer,
    reference_model: Union[PreTrainedModel, TFPreTrainedModel],
    onnx_model: Path,
    onnx_named_outputs: List[str],
    atol: float,
):
    """
    Run the reference model and the exported ONNX model on the same dummy inputs and compare their outputs

    Args:
        config: The ONNX configuration used to generate the dummy inputs
        tokenizer: The tokenizer used to generate the dummy inputs
        reference_model: The model the ONNX file was exported from
        onnx_model: Path of the exported ONNX file
        onnx_named_outputs: Names of the outputs to fetch from the ONNX model
        atol: Absolute tolerance used when comparing the values of both models

    Raises:
        ValueError: If an ONNX output name is missing from the reference outputs, or if the shape or the values of
            an output differ between both models
    """
    from onnxruntime import InferenceSession, SessionOptions

    logger.info("Validating ONNX model...")

    reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)

    # Create ONNX Runtime session
    options = SessionOptions()
    session = InferenceSession(onnx_model.as_posix(), options)

    # Compute outputs from the reference model
    ref_outputs = reference_model(**reference_model_inputs)
    ref_outputs_dict = {}

    # We flatten potential collection of outputs (i.e. past_keys) to a flat structure
    for name, value in ref_outputs.items():
        if isinstance(value, (list, tuple)):
            ref_outputs_dict.update(flatten_output_collection_property(name, value))
        else:
            ref_outputs_dict[name] = value

    # We flatten potential collection of inputs (i.e. past_keys)
    onnx_inputs = {}
    for name, value in reference_model_inputs.items():
        if isinstance(value, (list, tuple)):
            flattened = flatten_output_collection_property(name, value)
            onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in flattened.items()})
        else:
            onnx_inputs[name] = value.numpy()

    # Compute outputs from the ONNX model
    onnx_outputs = session.run(onnx_named_outputs, onnx_inputs)

    # Check we have a subset of the keys into onnx_outputs against ref_outputs
    ref_outputs_set, onnx_outputs_set = set(ref_outputs_dict.keys()), set(onnx_named_outputs)
    if not onnx_outputs_set.issubset(ref_outputs_set):
        logger.info(
            f"\t-[x] ONNX model outputs' name {onnx_outputs_set} doesn't match reference model {ref_outputs_set}"
        )

        raise ValueError(
            "Outputs doesn't match between reference model and ONNX exported model: "
            f"{onnx_outputs_set.difference(ref_outputs_set)}"
        )

    logger.info(f"\t-[✓] ONNX model outputs' name match reference model ({onnx_outputs_set}")

    # Check the shape and values match
    for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
        # Detach first: `.numpy()` fails on tensors requiring grad, which is what the reference model returns
        # unless autograd was disabled before calling this function.
        ref_value = ref_outputs_dict[name].detach().numpy()
        logger.info(f'\t- Validating ONNX Model output "{name}":')

        # Shape
        if not ort_value.shape == ref_value.shape:
            logger.info(f"\t\t-[x] shape {ort_value.shape} doesn't match {ref_value.shape}")
            raise ValueError(
                "Outputs shape doesn't match between reference model and ONNX exported model: "
                f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
            )

        logger.info(f"\t\t-[✓] {ort_value.shape} matchs {ref_value.shape}")

        # Values
        if not np.allclose(ref_value, ort_value, atol=atol):
            logger.info(f"\t\t-[x] values not close enough (atol: {atol})")
            raise ValueError(
                "Outputs values doesn't match between reference model and ONNX exported model: "
                f"Got max absolute difference of: {np.amax(np.abs(ref_value - ort_value))}"
            )

        logger.info(f"\t\t-[✓] all values close (atol: {atol})")
def ensure_model_and_config_inputs_match(
    model: Union[PreTrainedModel, TFPreTrainedModel], model_inputs: Iterable[str]
) -> Tuple[bool, List[str]]:
    """
    Compare a set of input names against the parameters of the model's forward method

    Args:
        model: The model whose `forward` signature is inspected
        model_inputs: The names of the inputs which will be provided to the model

    Returns:
        (bool) True if every name of `model_inputs` is a parameter of `model.forward`, (List[str]) the names of
        `model_inputs` accepted by `model.forward`, ordered as in the forward signature
    """
    signature_names = list(signature(model.forward).parameters.keys())
    requested_names = set(model_inputs)

    # The forward method is allowed to accept more parameters than the ones provided
    every_input_accepted = requested_names <= set(signature_names)

    # The order of the forward signature is the one which has to be preserved
    accepted_in_order = [name for name in signature_names if name in requested_names]
    return every_input_accepted, accepted_in_order
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ctypes import c_float, sizeof
from enum import Enum
from typing import Any, Dict, Iterable
class ParameterFormat(Enum):
    """
    Storage formats a parameter can be serialized with, each member wrapping the matching ctypes type.
    """

    Float = c_float

    @property
    def size(self) -> int:
        """
        Size, in bytes, of a single value stored with this format

        Returns:
            The result of `ctypes.sizeof` on the wrapped ctypes type
        """
        wrapped_ctype = self.value
        return sizeof(wrapped_ctype)
def compute_effective_axis_dimension(dimension: int, fixed_dimension: int, num_token_to_add: int = 0) -> int:
    """
    Resolve the concrete size to use for an axis when generating dummy inputs

    Args:
        dimension: The requested axis size, any value <= 0 stands for a dynamic axis
        fixed_dimension: The size to fall back to when the axis is dynamic
        num_token_to_add: Number of tokens deducted from the resolved size

    Returns:
        `dimension` (or `fixed_dimension` when `dimension` <= 0) minus `num_token_to_add`
    """
    # A strictly positive dimension is kept, anything else is replaced by the fixed one
    resolved = dimension if dimension > 0 else fixed_dimension
    return resolved - num_token_to_add
def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterFormat) -> int:
    """
    Compute the size taken by all the parameters when they are serialized with the given storage format

    Args:
        num_parameters: Number of parameters to be saved
        dtype: The storage format of each parameter, read through its `size` attribute

    Returns:
        Size (in byte) taken to save all the parameters
    """
    bytes_per_parameter = dtype.size
    return num_parameters * bytes_per_parameter
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
    """
    Flatten one level of nesting, keying every inner element by the field name and its running index.

    Args:
        name: The name of the nested structure, used as prefix of every generated key
        field: An iterable of iterables to flatten

    Returns:
        (Dict[str, Any]): The inner elements, in iteration order, stored under the keys `{name}.{index}`
    """
    flattened = {}
    position = 0

    for group in field:
        for element in group:
            flattened[f"{name}.{position}"] = element
            position += 1

    return flattened
...@@ -33,6 +33,7 @@ from .file_utils import ( ...@@ -33,6 +33,7 @@ from .file_utils import (
is_datasets_available, is_datasets_available,
is_faiss_available, is_faiss_available,
is_flax_available, is_flax_available,
is_keras2onnx_available,
is_onnx_available, is_onnx_available,
is_pandas_available, is_pandas_available,
is_rjieba_available, is_rjieba_available,
...@@ -234,6 +235,13 @@ def require_rjieba(test_case): ...@@ -234,6 +235,13 @@ def require_rjieba(test_case):
return test_case return test_case
def require_keras2onnx(test_case):
    """
    Decorator marking a test that requires keras2onnx.

    The test is returned untouched when `is_keras2onnx_available()` is true and wrapped with `unittest.skip`
    otherwise.
    """
    if is_keras2onnx_available():
        return test_case

    return unittest.skip("test requires keras2onnx")(test_case)
def require_onnx(test_case): def require_onnx(test_case):
if not is_onnx_available(): if not is_onnx_available():
return unittest.skip("test requires ONNX")(test_case) return unittest.skip("test requires ONNX")(test_case)
......
...@@ -35,7 +35,7 @@ from transformers.testing_utils import ( ...@@ -35,7 +35,7 @@ from transformers.testing_utils import (
_tf_gpu_memory_limit, _tf_gpu_memory_limit,
is_pt_tf_cross_test, is_pt_tf_cross_test,
is_staging_test, is_staging_test,
require_onnx, require_keras2onnx,
require_tf, require_tf,
slow, slow,
tooslow, tooslow,
...@@ -325,7 +325,7 @@ class TFModelTesterMixin: ...@@ -325,7 +325,7 @@ class TFModelTesterMixin:
self.assertEqual(len(incompatible_ops), 0, incompatible_ops) self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
@require_onnx @require_keras2onnx
@slow @slow
def test_onnx_runtime_optimize(self): def test_onnx_runtime_optimize(self):
if not self.test_onnx: if not self.test_onnx:
......
from pathlib import Path
from tempfile import NamedTemporaryFile
from unittest import TestCase
from unittest.mock import patch
from transformers import ( # LongformerConfig,
AlbertConfig,
AutoTokenizer,
BartConfig,
DistilBertConfig,
GPT2Config,
RobertaConfig,
T5Config,
XLMRobertaConfig,
is_torch_available,
)
from transformers.models.albert import AlbertOnnxConfig
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig
# from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
from transformers.onnx.utils import (
compute_effective_axis_dimension,
compute_serialized_parameters_size,
flatten_output_collection_property,
)
from transformers.testing_utils import require_onnx, require_torch, slow
@require_onnx
class OnnxUtilsTestCaseV2(TestCase):
    """
    Cover all the utilities involved to export ONNX models
    """

    def test_compute_effective_axis_dimension(self):
        """
        When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.
        We cannot generate an effective tensor with axis dim == -1, so we trick by using some "fixed" values
        (> 1 to avoid ONNX squeezing the axis).

        This test ensure we are correctly replacing generated batch / sequence tensor with axis > 1
        """

        # Dynamic axis requested with -1 (batch, no token added by the tokenizer)
        self.assertEqual(compute_effective_axis_dimension(-1, fixed_dimension=2, num_token_to_add=0), 2)

        # 0 is handled like a dynamic axis too (batch, no token added by the tokenizer)
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=2, num_token_to_add=0), 2)

        # Static axis: an explicit positive dimension is kept (batch, no token added by the tokenizer)
        self.assertEqual(compute_effective_axis_dimension(4, fixed_dimension=2, num_token_to_add=0), 4)

        # Dynamic axis (sequence, token added by the tokenizer 2 (no pair))
        self.assertEqual(compute_effective_axis_dimension(-1, fixed_dimension=8, num_token_to_add=2), 6)
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=2), 6)

        # Dynamic axis (sequence, token added by the tokenizer 3 (pair))
        self.assertEqual(compute_effective_axis_dimension(-1, fixed_dimension=8, num_token_to_add=3), 5)
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=3), 5)

    def test_compute_parameters_serialized_size(self):
        """
        This test ensures we compute a "correct" approximation of the underlying storage requirement (size) for all the
        parameters for the specified parameter's dtype.
        """
        self.assertEqual(compute_serialized_parameters_size(2, ParameterFormat.Float), 2 * ParameterFormat.Float.size)

    def test_flatten_output_collection_property(self):
        """
        This test ensures we correctly flatten nested collection such as the one we use when returning past_keys.
        past_keys = Tuple[Tuple]

        ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
        """
        self.assertEqual(
            flatten_output_collection_property("past_key", [[0], [1], [2]]),
            {
                "past_key.0": 0,
                "past_key.1": 1,
                "past_key.2": 2,
            },
        )
class OnnxConfigTestCaseV2(TestCase):
    """
    Cover the test for models default.

    Default means no specific features is being enabled on the model.
    """

    @patch.multiple(OnnxConfig, __abstractmethods__=set())
    def test_use_external_data_format(self):
        """
        External data format is required only if the serialized size of the parameters is at least 2Gb
        """

        TWO_GB_LIMIT = EXTERNAL_DATA_FORMAT_SIZE_LIMIT

        # `use_external_data_format` takes a number of parameters, so byte sizes are converted with the float size
        float_size = ParameterFormat.Float.size

        # No parameters
        self.assertFalse(OnnxConfig.use_external_data_format(0))

        # Some parameters
        self.assertFalse(OnnxConfig.use_external_data_format(1))

        # Almost 2Gb parameters
        self.assertFalse(OnnxConfig.use_external_data_format((TWO_GB_LIMIT - 1) // float_size))

        # Exactly 2Gb parameters
        self.assertTrue(OnnxConfig.use_external_data_format(TWO_GB_LIMIT // float_size))

        # More than 2Gb parameters
        self.assertTrue(OnnxConfig.use_external_data_format(TWO_GB_LIMIT // float_size + 1))
class OnnxConfigWithPastTestCaseV2(TestCase):
    """
    Cover the tests for models which have the use_cache feature (i.e. "with_past" for ONNX)
    """

    # (display name used for the sub-test, model configuration class)
    SUPPORTED_WITH_PAST_CONFIGS = {("BART", BartConfig), ("GPT2", GPT2Config), ("T5", T5Config)}

    @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
    def test_use_past(self):
        """
        Ensure the use_past variable is correctly being set:
        falsy when built through `default()`, truthy when built through `with_past()`.
        """
        for name, config_class in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):
                self.assertFalse(
                    OnnxConfigWithPast.default(config_class()).use_past,
                    "OnnxConfigWithPast.default() should not use_past",
                )

                self.assertTrue(
                    OnnxConfigWithPast.with_past(config_class()).use_past,
                    "OnnxConfigWithPast.with_past() should use_past",
                )

    @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
    def test_values_override(self):
        """
        Ensure the use_past variable correctly sets the `use_cache` value in `values_override`:
        falsy for `default()`, truthy for `with_past()`.
        """
        for name, config_class in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):
                # without past
                onnx_config_default = OnnxConfigWithPast.default(config_class())
                self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
                self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
                self.assertFalse(
                    onnx_config_default.values_override["use_cache"], "use_cache should be False if not using past"
                )

                # with past
                onnx_config_with_past = OnnxConfigWithPast.with_past(config_class())
                self.assertIsNotNone(onnx_config_with_past.values_override, "values_override should not be None")
                self.assertIn("use_cache", onnx_config_with_past.values_override, "use_cache should be present")
                self.assertTrue(
                    onnx_config_with_past.values_override["use_cache"], "use_cache should be True if using past"
                )
if is_torch_available():
    from transformers import (
        AlbertModel,
        BartModel,
        BertModel,
        DistilBertModel,
        GPT2Model,
        RobertaModel,
        T5Model,
        XLMRobertaModel,
    )

    # Each entry is:
    # (sub-test name, name given to AutoTokenizer.from_pretrained, model class, config class, ONNX config class)
    PYTORCH_EXPORT_DEFAULT_MODELS = {
        ("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
        ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
        ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
        ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
        ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
        # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
        ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
        ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
        ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
    }

    # Same entry layout as above. Every candidate is currently commented out, so the collection is empty.
    # An explicit `set()` is used because empty braces would create a dict, not a set.
    PYTORCH_EXPORT_WITH_PAST_MODELS = set()
    # Disabled candidates:
    # ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
    # ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
    # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
class OnnxExportTestCaseV2(TestCase):
    """
    Integration tests ensuring supported models are correctly exported
    """

    @slow
    @require_torch
    def test_pytorch_export_default(self):
        """
        Export every model listed in PYTORCH_EXPORT_DEFAULT_MODELS with its default ONNX config and validate the
        exported model's outputs. A ValueError raised by the validation is reported as a test failure.
        """
        from transformers.onnx import export

        for name, checkpoint, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
            with self.subTest(name):
                self.assertTrue(hasattr(onnx_config_class, "default"))

                tokenizer = AutoTokenizer.from_pretrained(checkpoint)
                # The model is built from a default-constructed config, not loaded with from_pretrained
                model = model_class(config_class())
                onnx_config = onnx_config_class.default(model.config)

                with NamedTemporaryFile("w") as output:
                    output_path = Path(output.name)
                    _, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output_path)

                    try:
                        validate_model_outputs(onnx_config, tokenizer, model, output_path, onnx_outputs, 1e-5)
                    except ValueError as ve:
                        self.fail(f"{name} -> {ve}")

    @slow
    @require_torch
    def test_pytorch_export_with_past(self):
        """
        Export every model listed in PYTORCH_EXPORT_WITH_PAST_MODELS with its "with_past" ONNX config and validate
        the exported model's outputs. A ValueError raised by the validation is reported as a test failure.
        """
        from transformers.onnx import export

        for name, checkpoint, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_WITH_PAST_MODELS:
            with self.subTest(name):
                self.assertTrue(hasattr(onnx_config_class, "with_past"), "OnnxConfigWithPast should have with_past()")

                tokenizer = AutoTokenizer.from_pretrained(checkpoint)
                # The model is built from a default-constructed config, not loaded with from_pretrained
                model = model_class(config_class())
                onnx_config = onnx_config_class.with_past(model.config)

                self.assertTrue(hasattr(onnx_config, "use_past"), "OnnxConfigWithPast should have use_past attribute.")
                self.assertTrue(
                    onnx_config.use_past, "OnnxConfigWithPast.use_past should be True if called with with_past()"
                )

                with NamedTemporaryFile("w") as output:
                    # Keep the file handle and its path under distinct names
                    output_path = Path(output.name)
                    _, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output_path)

                    try:
                        validate_model_outputs(onnx_config, tokenizer, model, output_path, onnx_outputs, 1e-5)
                    except ValueError as ve:
                        self.fail(f"{name} -> {ve}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment