"vscode:/vscode.git/clone" did not exist on "93f31e0e78e6d4bc7341ff3d34d60d78dafe1128"
Unverified Commit a6d62aab authored by Michael Benayoun, committed by GitHub

GPT-Neo ONNX export (#12911)



GPT-Neo ONNX export and task / feature refactoring
Authored-by: Michael Benayoun <michael@huggingface.co>
parent 8aa01d2a
src/transformers/models/gpt_neo/__init__.py

@@ -21,7 +21,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_torch_available
 _import_structure = {
-    "configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
+    "configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig", "GPTNeoOnnxConfig"],
 }
 if is_torch_available():
@@ -43,7 +43,7 @@ if is_flax_available():
 if TYPE_CHECKING:
-    from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
+    from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig, GPTNeoOnnxConfig
 if is_torch_available():
     from .modeling_gpt_neo import (
...
src/transformers/models/gpt_neo/configuration_gpt_neo.py

@@ -14,7 +14,12 @@
 # limitations under the License.
 """ GPT Neo model configuration """
+from collections import OrderedDict
+from typing import Any, Mapping, Optional
+
+from ... import PreTrainedTokenizer, TensorType, is_torch_available
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast, PatchingSpec
 from ...utils import logging
@@ -173,3 +178,140 @@ class GPTNeoConfig(PretrainedConfig):
     @property
     def num_hidden_layers(self):
         return self.num_layers
+
+
+def custom_unfold(input, dimension, size, step):
+    """Custom torch.Tensor.unfold implementation to enable the export to ONNX."""
+    import torch
+
+    shape = input.size()
+    rank = len(shape)
+    sizedim = shape[dimension]
+
+    low_indices = torch.arange(0, sizedim, step)
+    min_length = torch.div(sizedim - size, step, rounding_mode="floor") + 1
+    indices = torch.arange(size) + low_indices[:min_length][:, None]
+
+    s = [slice(None)] * rank
+    s[dimension] = indices
+    sliced = input[s]
+
+    perm = list(range(0, rank + 1))
+    perm.append(perm.pop(dimension + 1))
+
+    return sliced.permute(perm)
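A quick way to convince yourself the reimplementation matches `torch.Tensor.unfold` (an illustrative check, not part of the commit; `custom_unfold` relies on list-based advanced indexing, so run it on a PyTorch version contemporary with this commit):

```python
import torch

x = torch.arange(20).reshape(2, 10)

# torch.Tensor.unfold(dim, size, step): sliding windows of length 4, stride 2
expected = x.unfold(1, 4, 2)                            # shape (2, 4, 4)
actual = custom_unfold(x, dimension=1, size=4, step=2)  # same windows via gather-style indexing

assert torch.equal(actual, expected)
```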
+def custom_get_block_length_and_num_blocks(seq_length, window_size):
+    """
+    Custom implementation for GPTNeoAttentionMixin._get_block_length_and_num_blocks to enable the export to ONNX as
+    the original implementation uses Python variables and control flow.
+    """
+    import torch
+
+    candidates = torch.arange(1, window_size)
+    remainders = torch.remainder(seq_length, candidates)
+    divisor_indices = remainders == 0
+    divisors = candidates[divisor_indices]
+    largest_divisor = torch.max(divisors)
+    return largest_divisor, torch.div(seq_length, largest_divisor, rounding_mode="floor")
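Illustrative check (not from the commit): with GPT-Neo's local attention window of 256 and a 512-token sequence, the largest divisor of 512 strictly below the window size is 128, so the sequence splits into 4 blocks:

```python
import torch

# candidates run over 1..window_size-1, so 256 itself is excluded and 128 wins
block_length, num_blocks = custom_get_block_length_and_num_blocks(torch.tensor(512), 256)
assert block_length.item() == 128 and num_blocks.item() == 4
```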
+class GPTNeoOnnxConfig(OnnxConfigWithPast):
+    def __init__(self, config: PretrainedConfig, task: str = "default", use_past: bool = False):
+        if is_torch_available():
+            import torch
+
+            from .modeling_gpt_neo import GPTNeoAttentionMixin
+
+            patching_specs = [
+                PatchingSpec(torch.Tensor, name="unfold", custom_op=custom_unfold),
+                PatchingSpec(
+                    GPTNeoAttentionMixin,
+                    name="_get_block_length_and_num_blocks",
+                    custom_op=custom_get_block_length_and_num_blocks,
+                    op_wrapper=staticmethod,
+                ),
+            ]
+
+        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
+
+        self._num_local_attention = len([type_ for type_ in self._config.attention_layers if type_ == "local"])
+        self._key_values_dynamic_axis = []
+        for i in range(self._config.num_layers):
+            if self._config.attention_layers[i] == "local":
+                self._key_values_dynamic_axis.append({0: "batch", 1: "sequence"})
+            else:
+                self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
+                self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
+
+    @property
+    def _number_key_values(self):
+        return (self._config.num_layers * 2) - self._num_local_attention
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
+        if self.use_past:
+            for i in range(self._number_key_values):
+                common_inputs[f"past_key_values.{i}"] = self._key_values_dynamic_axis[i]
+
+        common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
+        return common_inputs
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_outputs = super().outputs
+        if self.use_past:
+            for i in range(self._number_key_values):
+                common_outputs[f"present.{i}"] = self._key_values_dynamic_axis[i]
+
+        return common_outputs
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
+
+        # We need to order the inputs in the way they appear in the forward()
+        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
+
+        batch = common_inputs["input_ids"].shape[0]
+        past_shapes = {
+            "global": (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_attention_heads),
+            "local": (batch, 1, self._config.hidden_size),
+        }
+
+        # Need to add the past_keys
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+
+                ordered_inputs["past_key_values"] = []
+                for i in range(self._config.num_layers):
+                    attention_type = self._config.attention_layers[i]
+                    if attention_type == "global":
+                        ordered_inputs["past_key_values"].append(
+                            (
+                                torch.zeros(past_shapes[attention_type]),
+                                torch.zeros(past_shapes[attention_type]),
+                            )
+                        )
+                    else:
+                        ordered_inputs["past_key_values"].append((torch.zeros(past_shapes[attention_type]),))
+
+        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
+        if self.use_past:
+            ordered_inputs["attention_mask"] = torch.cat(
+                [ordered_inputs["attention_mask"], torch.zeros(batch, 1)], dim=1
+            )
+
+        return ordered_inputs
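To see the resulting axis layout, one can instantiate the config for a past-enabled export (illustrative usage, not part of the commit; requires PyTorch, and the checkpoint name is the one used in the tests below). Global layers cache a key and a value tensor while local layers cache a single tensor, which is why there are `(num_layers * 2) - num_local_attention` past entries and why the `else` branch above appends two dynamic-axis entries:

```python
from transformers import GPTNeoConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig

config = GPTNeoConfig.from_pretrained("EleutherAI/gpt-neo-125M")
onnx_config = GPTNeoOnnxConfig.with_past(config, task="causal-lm")

print(list(onnx_config.inputs.keys()))   # ['input_ids', 'past_key_values.0', ..., 'attention_mask']
print(list(onnx_config.outputs.keys()))  # ['logits', 'present.0', 'present.1', ...]
```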
src/transformers/models/gpt_neo/modeling_gpt_neo.py

@@ -1121,7 +1121,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
                 f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
             )
-        pooled_logits = logits[range(batch_size), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
         loss = None
         if labels is not None:
...
src/transformers/onnx/__init__.py

@@ -13,6 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast
+from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast, PatchingSpec
 from .convert import export, validate_model_outputs
 from .utils import ParameterFormat, compute_serialized_parameters_size
src/transformers/onnx/__main__.py

@@ -14,101 +14,22 @@
 from argparse import ArgumentParser
 from pathlib import Path
-from typing import Callable, Tuple

-from transformers.models.albert import AlbertOnnxConfig
 from transformers.models.auto import AutoTokenizer
-from transformers.models.bart import BartOnnxConfig
-from transformers.models.bert import BertOnnxConfig
-from transformers.models.distilbert import DistilBertOnnxConfig
-from transformers.models.gpt2 import GPT2OnnxConfig
-from transformers.models.longformer import LongformerOnnxConfig
-from transformers.models.roberta import RobertaOnnxConfig
-from transformers.models.t5 import T5OnnxConfig
-from transformers.models.xlm_roberta import XLMRobertaOnnxConfig

-from .. import is_torch_available
 from ..utils import logging
 from .convert import export, validate_model_outputs
+from .features import FeaturesManager

-if is_torch_available():
-    from transformers import AutoModel, PreTrainedModel
-
-    FEATURES_TO_AUTOMODELS = {
-        "default": AutoModel,
-    }
-
-# Set of model topologies we support associated to the features supported by each topology and the factory
-SUPPORTED_MODEL_KIND = {
-    "albert": {"default": AlbertOnnxConfig.default},
-    "bart": {"default": BartOnnxConfig.default},
-    "bert": {"default": BertOnnxConfig.default},
-    "distilbert": {"default": DistilBertOnnxConfig.default},
-    "gpt2": {"default": GPT2OnnxConfig.default},
-    "longformer": {"default": LongformerOnnxConfig.default},
-    "roberta": {"default": RobertaOnnxConfig},
-    "t5": {"default": T5OnnxConfig.default},
-    "xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
-}
-
-
-def get_model_from_features(features: str, model: str):
-    """
-    Attempt to retrieve a model from a model's name and the features to be enabled.
-
-    Args:
-        features: The features required
-        model: The name of the model to export
-
-    Returns:
-
-    """
-    if features not in FEATURES_TO_AUTOMODELS:
-        raise KeyError(f"Unknown feature: {features}." f"Possible values are {list(FEATURES_TO_AUTOMODELS.values())}")
-
-    return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
-
-
-def check_supported_model_or_raise(model: PreTrainedModel, features: str = "default") -> Tuple[str, Callable]:
-    """
-    Check whether or not the model has the requested features
-
-    Args:
-        model: The model to export
-        features: The name of the features to check if they are avaiable
-
-    Returns:
-        (str) The type of the model
-        (OnnxConfig) The OnnxConfig instance holding the model export properties
-
-    """
-    if model.config.model_type not in SUPPORTED_MODEL_KIND:
-        raise KeyError(
-            f"{model.config.model_type} ({model.name}) is not supported yet. "
-            f"Only {SUPPORTED_MODEL_KIND} are supported. "
-            f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
-        )
-
-    # Look for the features
-    model_features = SUPPORTED_MODEL_KIND[model.config.model_type]
-    if features not in model_features:
-        raise ValueError(
-            f"{model.config.model_type} doesn't support features {features}. "
-            f"Supported values are: {list(model_features.keys())}"
-        )
-
-    return model.config.model_type, SUPPORTED_MODEL_KIND[model.config.model_type][features]
-
-
 def main():
     parser = ArgumentParser("Hugging Face ONNX Exporter tool")
     parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.")
     parser.add_argument(
-        "--features",
-        choices=["default"],
+        "--feature",
+        choices=list(FeaturesManager.AVAILABLE_FEATURES),
         default="default",
-        help="Export the model with some additional features.",
+        help="Export the model with some additional feature.",
     )
     parser.add_argument(
         "--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
@@ -127,8 +48,8 @@ def main():
     # Allocate the model
     tokenizer = AutoTokenizer.from_pretrained(args.model)
-    model = get_model_from_features(args.features, args.model)
-    model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
+    model = FeaturesManager.get_model_from_feature(args.feature, args.model)
+    model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
     onnx_config = model_onnx_config(model.config)

     # Ensure the requested opset is sufficient
...
src/transformers/onnx/config.py

@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import dataclasses
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Any, Mapping, Optional
+from typing import Any, Callable, List, Mapping, Optional

 from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
@@ -26,6 +27,27 @@ DEFAULT_ONNX_OPSET = 11
 EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024

+@dataclasses.dataclass
+class PatchingSpec:
+    """
+    Data class that holds patching specifications.
+
+    Args:
+        o: Module / object where the op to patch is located
+        name: Name of the op to monkey patch
+        custom_op: Custom op that patches the original op
+        orig_op: Original op that is being patched
+        op_wrapper: Wrapper (optional) that wraps both the original and custom ops.
+            It is useful for ops that are class or static methods for instance.
+    """
+
+    o: Any
+    name: str
+    custom_op: Callable
+    orig_op: Optional[Callable] = None
+    op_wrapper: Optional[Callable] = None
+
+
 class OnnxConfig(ABC):
     """
     Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
@@ -34,11 +56,38 @@ class OnnxConfig(ABC):
     DEFAULT_FIXED_BATCH = 2
     DEFAULT_FIXED_SEQUENCE = 8

-    def __init__(self, config: PretrainedConfig):
+    _TASKS_TO_COMMON_OUTPUTS = {
+        "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
+        "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+        "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
+        "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+        "multiple-choice": OrderedDict({"logits": {0: "batch"}}),
+        "question-answering": OrderedDict(
+            {
+                "start_logits": {0: "batch", 1: "sequence"},
+                "end_logits": {0: "batch", 1: "sequence"},
+            }
+        ),
+    }
+
+    def __init__(self, config: PretrainedConfig, task: str = "default", patching_specs: List[PatchingSpec] = None):
         self._config = config
+
+        if task not in self._TASKS_TO_COMMON_OUTPUTS:
+            raise ValueError(
+                f"{task} is not a supported task, supported tasks: {self._TASKS_TO_COMMON_OUTPUTS.keys()}"
+            )
+        self.task = task
+
+        self._patching_specs = []
+        for spec in patching_specs if patching_specs is not None else []:
+            final_spec = spec
+            if spec.orig_op is None:
+                final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
+            self._patching_specs.append(final_spec)

     @classmethod
-    def default(cls, config: PretrainedConfig) -> "OnnxConfig":
+    def from_model_config(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfig":
         """
         Instantiate a OnnxConfig for a specific model
@@ -48,7 +97,7 @@ class OnnxConfig(ABC):
         Returns:
             OnnxConfig for this model
         """
-        return cls(config)
+        return cls(config, task=task)

     @property
     @abstractmethod
@@ -62,7 +111,6 @@ class OnnxConfig(ABC):
         raise NotImplementedError()

     @property
-    @abstractmethod
     def outputs(self) -> Mapping[str, Mapping[int, str]]:
         """
         Mapping containing the axis definition of the output tensors to provide to the model
@@ -70,7 +118,7 @@ class OnnxConfig(ABC):
         Returns:
             For each output: its name associated to the axes symbolic name and the axis position within the tensor
         """
-        raise NotImplementedError()
+        return self._TASKS_TO_COMMON_OUTPUTS[self.task]

     @property
     def values_override(self) -> Optional[Mapping[str, Any]]:
@@ -170,14 +218,30 @@ class OnnxConfig(ABC):
         dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
         return dict(tokenizer(dummy_input, return_tensors=framework))

+    def patch_ops(self):
+        for spec in self._patching_specs:
+            custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
+            setattr(spec.o, spec.name, custom_op)
+
+    def restore_ops(self):
+        for spec in self._patching_specs:
+            orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
+            setattr(spec.o, spec.name, orig_op)
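A minimal sketch (toy classes, not from the commit) of what `patch_ops` / `restore_ops` do with a `PatchingSpec` whose target is a static method, mirroring the GPT-Neo specs above:

```python
import dataclasses

class Ops:
    @staticmethod
    def double(x):
        return 2 * x  # imagine an op the ONNX tracer cannot handle

def export_friendly_double(x):
    return x + x      # numerically identical, export-friendly

spec = PatchingSpec(o=Ops, name="double", custom_op=export_friendly_double, op_wrapper=staticmethod)
spec = dataclasses.replace(spec, orig_op=Ops.double)  # OnnxConfig.__init__ fills orig_op this way

# patch_ops equivalent: swap in the custom op, re-wrapped as a staticmethod
setattr(spec.o, spec.name, spec.op_wrapper(spec.custom_op))
assert Ops.double is export_friendly_double

# restore_ops equivalent: put the original back once the export is done
setattr(spec.o, spec.name, spec.op_wrapper(spec.orig_op))
assert Ops.double(3) == 6
```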
 class OnnxConfigWithPast(OnnxConfig, ABC):
-    def __init__(self, config: PretrainedConfig, use_past: bool = False):
-        super().__init__(config)
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        task: str = "default",
+        patching_specs: List[PatchingSpec] = None,
+        use_past: bool = False,
+    ):
+        super().__init__(config, task=task, patching_specs=patching_specs)
         self.use_past = use_past

     @classmethod
-    def with_past(cls, config: PretrainedConfig) -> "OnnxConfigWithPast":
+    def with_past(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfigWithPast":
         """
         Instantiate a OnnxConfig with `use_past` attribute set to True
@@ -187,7 +251,7 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
         Returns:
             OnnxConfig with `.use_past = True`
         """
-        return cls(config, use_past=True)
+        return cls(config, task=task, use_past=True)

     @property
     def values_override(self) -> Optional[Mapping[str, Any]]:
...
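With the task plumbing in place, `outputs` no longer needs to be overridden per model for the common heads; the task selects the axes. A sketch (assumes PyTorch is installed, since `GPTNeoOnnxConfig.__init__` builds its patching specs from torch):

```python
from transformers import GPTNeoConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig

config = GPTNeoConfig()

print(GPTNeoOnnxConfig.from_model_config(config, task="default").outputs)
# OrderedDict([('last_hidden_state', {0: 'batch', 1: 'sequence'})])

print(GPTNeoOnnxConfig.from_model_config(config, task="sequence-classification").outputs)
# OrderedDict([('logits', {0: 'batch'})])
```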
src/transformers/onnx/convert.py

@@ -111,6 +111,8 @@ def export(
     if not inputs_match:
         raise ValueError("Model and config inputs doesn't match")

+    config.patch_ops()
+
     # export can works with named args but the dict containing named args as to be last element of the args tuple
     export(
         model,
@@ -125,6 +127,8 @@ def export(
         opset_version=opset,
     )

+    config.restore_ops()
+
     return matched_inputs, onnx_outputs
@@ -140,6 +144,8 @@ def validate_model_outputs(
     logger.info("Validating ONNX model...")

+    # TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
+    # dynamic input shapes.
     reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)

     # Create ONNX Runtime session
@@ -152,6 +158,10 @@ def validate_model_outputs(
     # We flatten potential collection of outputs (i.e. past_keys) to a flat structure
     for name, value in ref_outputs.items():
+        # Overwriting the output name as "present" since it is the name used for the ONNX outputs
+        # ("past_key_values" being taken for the ONNX inputs)
+        if name == "past_key_values":
+            name = "present"
         if isinstance(value, (list, tuple)):
             value = flatten_output_collection_property(name, value)
             ref_outputs_dict.update(value)
@@ -186,7 +196,7 @@ def validate_model_outputs(
     # Check the shape and values match
     for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
-        ref_value = ref_outputs_dict[name].numpy()
+        ref_value = ref_outputs_dict[name].detach().numpy()
         logger.info(f'\t- Validating ONNX Model output "{name}":')

         # Shape
@@ -197,7 +207,7 @@ def validate_model_outputs(
                 f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
             )
         else:
-            logger.info(f"\t\t-[✓] {ort_value.shape} matchs {ref_value.shape}")
+            logger.info(f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}")

         # Values
         if not np.allclose(ref_value, ort_value, atol=atol):
...
src/transformers/onnx/features.py (new file)

+from functools import partial, reduce
+from typing import Callable, Tuple
+
+from .. import is_torch_available
+from ..models.albert import AlbertOnnxConfig
+from ..models.bart import BartOnnxConfig
+from ..models.bert import BertOnnxConfig
+from ..models.distilbert import DistilBertOnnxConfig
+from ..models.gpt2 import GPT2OnnxConfig
+from ..models.gpt_neo import GPTNeoOnnxConfig
+from ..models.longformer import LongformerOnnxConfig
+from ..models.roberta import RobertaOnnxConfig
+from ..models.t5 import T5OnnxConfig
+from ..models.xlm_roberta import XLMRobertaOnnxConfig
+
+if is_torch_available():
+    from transformers import PreTrainedModel
+    from transformers.models.auto import (
+        AutoModel,
+        AutoModelForCausalLM,
+        AutoModelForMultipleChoice,
+        AutoModelForQuestionAnswering,
+        AutoModelForSequenceClassification,
+        AutoModelForTokenClassification,
+    )
+
+
+def supported_features_mapping(*supported_features, onnx_config_cls=None):
+    """Generates the mapping between supported features and their corresponding OnnxConfig."""
+    if onnx_config_cls is None:
+        raise ValueError("An OnnxConfig class must be provided")
+
+    mapping = {}
+    for feature in supported_features:
+        if "-with-past" in feature:
+            task = feature.replace("-with-past", "")
+            mapping[feature] = partial(onnx_config_cls.with_past, task=task)
+        else:
+            mapping[feature] = partial(onnx_config_cls.from_model_config, task=feature)
+
+    return mapping
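For instance (illustrative, not part of the commit), the GPT-Neo entry below expands into one partial per feature, with "-with-past" features routed through `with_past`:

```python
mapping = supported_features_mapping("default", "causal-lm-with-past", onnx_config_cls=GPTNeoOnnxConfig)

assert mapping["default"].func == GPTNeoOnnxConfig.from_model_config
assert mapping["default"].keywords == {"task": "default"}
assert mapping["causal-lm-with-past"].func == GPTNeoOnnxConfig.with_past
assert mapping["causal-lm-with-past"].keywords == {"task": "causal-lm"}
```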
+class FeaturesManager:
+    _TASKS_TO_AUTOMODELS = {
+        "default": AutoModel,
+        "causal-lm": AutoModelForCausalLM,
+        "sequence-classification": AutoModelForSequenceClassification,
+        "token-classification": AutoModelForTokenClassification,
+        "multiple-choice": AutoModelForMultipleChoice,
+        "question-answering": AutoModelForQuestionAnswering,
+    }
+
+    # Set of model topologies we support associated to the features supported by each topology and the factory
+    _SUPPORTED_MODEL_KIND = {
+        "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
+        "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
+        "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
+        "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
+        "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
+        "longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
+        "roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
+        "t5": supported_features_mapping("default", onnx_config_cls=T5OnnxConfig),
+        "xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
+        "gpt-neo": supported_features_mapping(
+            "default",
+            "causal-lm",
+            "sequence-classification",
+            "default-with-past",
+            "causal-lm-with-past",
+            "sequence-classification-with-past",
+            onnx_config_cls=GPTNeoOnnxConfig,
+        ),
+    }
+
+    AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))
+
+    @staticmethod
+    def feature_to_task(feature: str) -> str:
+        return feature.replace("-with-past", "")
+
+    @staticmethod
+    def get_model_from_feature(feature: str, model: str):
+        """
+        Attempt to retrieve a model from a model's name and the feature to be enabled.
+
+        Args:
+            feature: The feature required
+            model: The name of the model to export
+
+        Returns:
+
+        """
+        task = FeaturesManager.feature_to_task(feature)
+        if task not in FeaturesManager._TASKS_TO_AUTOMODELS:
+            raise KeyError(
+                f"Unknown task: {feature}. "
+                f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
+            )
+
+        return FeaturesManager._TASKS_TO_AUTOMODELS[task].from_pretrained(model)
+
+    @staticmethod
+    def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
+        """
+        Check whether or not the model has the requested features
+
+        Args:
+            model: The model to export
+            feature: The name of the feature to check if it is available
+
+        Returns:
+            (str) The type of the model
+            (OnnxConfig) The OnnxConfig instance holding the model export properties
+
+        """
+        model_type = model.config.model_type.replace("_", "-")
+        model_name = getattr(model, "name", "")
+        model_name = f"({model_name})" if model_name else ""
+        if model_type not in FeaturesManager._SUPPORTED_MODEL_KIND:
+            raise KeyError(
+                f"{model.config.model_type} {model_name} is not supported yet. "
+                f"Only {FeaturesManager._SUPPORTED_MODEL_KIND} are supported. "
+                f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
+            )
+
+        # Look for the features
+        model_features = FeaturesManager._SUPPORTED_MODEL_KIND[model_type]
+        if feature not in model_features:
+            raise ValueError(
+                f"{model.config.model_type} doesn't support feature {feature}. "
+                f"Supported values are: {list(model_features.keys())}"
+            )
+
+        return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_KIND[model_type][feature]
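Putting the pieces together, the programmatic equivalent of the refactored CLI flow looks roughly like this (a sketch following `__main__.py` above; requires PyTorch and onnxruntime, and the `export` / `validate_model_outputs` argument order is taken from this version of `transformers.onnx`, so double-check against the module):

```python
from pathlib import Path

from transformers.models.auto import AutoTokenizer
from transformers.onnx import export, validate_model_outputs
from transformers.onnx.features import FeaturesManager

model_name = "EleutherAI/gpt-neo-125M"
feature = "causal-lm-with-past"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = FeaturesManager.get_model_from_feature(feature, model_name)
model_kind, config_factory = FeaturesManager.check_supported_model_or_raise(model, feature=feature)
onnx_config = config_factory(model.config)

output = Path("gpt-neo-125m.onnx")
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, 12, output)
validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, atol=1e-4)
```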
tests/test_onnx_v2.py

@@ -9,6 +9,7 @@ from transformers import (  # LongformerConfig,; T5Config,
     BartConfig,
     DistilBertConfig,
     GPT2Config,
+    GPTNeoConfig,
     RobertaConfig,
     XLMRobertaConfig,
     is_torch_available,
@@ -20,6 +21,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
 # from transformers.models.longformer import LongformerOnnxConfig
 from transformers.models.gpt2 import GPT2OnnxConfig
+from transformers.models.gpt_neo import GPTNeoOnnxConfig
 from transformers.models.roberta import RobertaOnnxConfig
 # from transformers.models.t5 import T5OnnxConfig
@@ -151,7 +153,8 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
         for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
             with self.subTest(name):
                 self.assertFalse(
-                    OnnxConfigWithPast.default(config()).use_past, "OnnxConfigWithPast.default() should not use_past"
+                    OnnxConfigWithPast.from_model_config(config()).use_past,
+                    "OnnxConfigWithPast.from_model_config() should not use_past",
                 )

                 self.assertTrue(
@@ -167,7 +170,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
             with self.subTest(name):
                 # without past
-                onnx_config_default = OnnxConfigWithPast.default(config())
+                onnx_config_default = OnnxConfigWithPast.from_model_config(config())
                 self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
                 self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
                 self.assertFalse(
@@ -190,6 +193,7 @@ if is_torch_available():
         BertModel,
         DistilBertModel,
         GPT2Model,
+        GPTNeoModel,
         RobertaModel,
         XLMRobertaModel,
     )
@@ -200,6 +204,7 @@ if is_torch_available():
         ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
         ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
         ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
+        ("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig),
         # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
         ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
         ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
...