Unverified commit a6d62aab authored by Michael Benayoun, committed by GitHub

GPT-Neo ONNX export (#12911)



GPT-Neo ONNX export and task / feature refactoring
Authored-by: Michael Benayoun <michael@huggingface.co>
parent 8aa01d2a
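This patch replaces the hard-coded per-model ONNX export checks with a FeaturesManager that maps each model type to its exportable features (default, causal-lm, sequence-classification and their -with-past variants) and adds a GPTNeoOnnxConfig. As an illustration only (not part of the patch; all names come from the diff below and the checkpoint is the one used in the tests), the resulting Python API can be driven roughly like this:

from transformers import AutoTokenizer
from transformers.onnx.features import FeaturesManager

# Load GPT-Neo with the AutoModel class matching the requested feature
model = FeaturesManager.get_model_from_feature("causal-lm", "EleutherAI/gpt-neo-125M")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")

# Look up the OnnxConfig factory registered for this (model type, feature) pair
model_kind, onnx_config_factory = FeaturesManager.check_supported_model_or_raise(model, feature="causal-lm")
onnx_config = onnx_config_factory(model.config)
# onnx_config.inputs / onnx_config.outputs now describe the ONNX graph axes; the
# actual export call lives in transformers.onnx.convert and is elided here since
# its signature is not shown in this diff.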
......@@ -21,7 +21,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_torch_available
_import_structure = {
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig", "GPTNeoOnnxConfig"],
}
if is_torch_available():
......@@ -43,7 +43,7 @@ if is_flax_available():
if TYPE_CHECKING:
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig, GPTNeoOnnxConfig
if is_torch_available():
from .modeling_gpt_neo import (
......
......@@ -14,7 +14,12 @@
# limitations under the License.
""" GPT Neo model configuration """
from collections import OrderedDict
from typing import Any, Mapping, Optional
from ... import PreTrainedTokenizer, TensorType, is_torch_available
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...utils import logging
......@@ -173,3 +178,140 @@ class GPTNeoConfig(PretrainedConfig):
@property
def num_hidden_layers(self):
return self.num_layers
def custom_unfold(input, dimension, size, step):
"""Custom torch.Tensor.unfold implementation to enable the export to ONNX."""
import torch
shape = input.size()
rank = len(shape)
sizedim = shape[dimension]
low_indices = torch.arange(0, sizedim, step)
min_length = torch.div(sizedim - size, step, rounding_mode="floor") + 1
indices = torch.arange(size) + low_indices[:min_length][:, None]
s = [slice(None)] * rank
s[dimension] = indices
sliced = input[s]
perm = list(range(0, rank + 1))
perm.append(perm.pop(dimension + 1))
return sliced.permute(perm)
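# Illustrative sanity check (not part of this patch): custom_unfold is meant to
# mirror torch.Tensor.unfold, e.g. with x = torch.arange(10).view(1, 10):
#     torch.equal(custom_unfold(x, dimension=1, size=4, step=2), x.unfold(1, 4, 2))  # True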
def custom_get_block_length_and_num_blocks(seq_length, window_size):
"""
Custom implementation for GPTNeoAttentionMixin._get_block_length_and_num_blocks to enable the export to ONNX as
original implementation uses Python variables and control flow.
"""
import torch
candidates = torch.arange(1, window_size)
remainders = torch.remainder(seq_length, candidates)
divisor_indices = remainders == 0
divisors = candidates[divisor_indices]
largest_divisor = torch.max(divisors)
return largest_divisor, torch.div(seq_length, largest_divisor, rounding_mode="floor")
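# Worked example (illustration only): with seq_length=512 and window_size=256,
# the largest divisor of 512 among 1..255 is 128, so this returns
# block_length=128 and num_blocks=4; with seq_length=13 it returns (13, 1).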
class GPTNeoOnnxConfig(OnnxConfigWithPast):
def __init__(self, config: PretrainedConfig, task: str = "default", use_past: bool = False):
if is_torch_available():
import torch
from .modeling_gpt_neo import GPTNeoAttentionMixin
patching_specs = [
PatchingSpec(torch.Tensor, name="unfold", custom_op=custom_unfold),
PatchingSpec(
GPTNeoAttentionMixin,
name="_get_block_length_and_num_blocks",
custom_op=custom_get_block_length_and_num_blocks,
op_wrapper=staticmethod,
),
]
super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
self._num_local_attention = len([type_ for type_ in self._config.attention_layers if type_ == "local"])
self._key_values_dynamic_axis = []
for i in range(self._config.num_layers):
if self._config.attention_layers[i] == "local":
self._key_values_dynamic_axis.append({0: "batch", 1: "sequence"})
else:
self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
@property
def _number_key_values(self):
return (self._config.num_layers * 2) - self._num_local_attention
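# Layout note (illustration): each global-attention layer caches a (key, value)
# pair with the dynamic sequence axis at position 2, while each local-attention
# layer caches a single tensor with the sequence axis at position 1, hence the
# num_layers * 2 - num_local_attention cached tensors counted above.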
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
if self.use_past:
for i in range(self._number_key_values):
common_inputs[f"past_key_values.{i}"] = self._key_values_dynamic_axis[i]
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
return common_inputs
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
common_outputs = super().outputs
if self.use_past:
for i in range(self._number_key_values):
common_outputs[f"present.{i}"] = self._key_values_dynamic_axis[i]
return common_outputs
def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
# We need to order the inputs in the way they appear in forward()
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
batch = common_inputs["input_ids"].shape[0]
past_shapes = {
"global": (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_attention_heads),
"local": (batch, 1, self._config.hidden_size),
}
# Need to add the past_keys
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
ordered_inputs["past_key_values"] = []
for i in range(self._config.num_layers):
attention_type = self._config.attention_layers[i]
if attention_type == "global":
ordered_inputs["past_key_values"].append(
(
torch.zeros(past_shapes[attention_type]),
torch.zeros(past_shapes[attention_type]),
)
)
else:
ordered_inputs["past_key_values"].append((torch.zeros(past_shapes[attention_type]),))
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
ordered_inputs["attention_mask"] = torch.cat(
[ordered_inputs["attention_mask"], torch.zeros(batch, 1)], dim=1
)
return ordered_inputs
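A minimal usage sketch for the class above (illustration only; assumes the EleutherAI/gpt-neo-125M checkpoint used in the tests is downloadable and PyTorch is installed):

from transformers import AutoTokenizer, GPTNeoConfig, TensorType
from transformers.models.gpt_neo import GPTNeoOnnxConfig

config = GPTNeoConfig.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
onnx_config = GPTNeoOnnxConfig.with_past(config, task="causal-lm")
print(list(onnx_config.inputs))  # input_ids, past_key_values.0 ..., attention_mask
dummy = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)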
......@@ -1121,7 +1121,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
loss = None
if labels is not None:
......
......@@ -13,6 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast, PatchingSpec
from .convert import export, validate_model_outputs
from .utils import ParameterFormat, compute_serialized_parameters_size
......@@ -14,101 +14,22 @@
from argparse import ArgumentParser
from pathlib import Path
from typing import Callable, Tuple
from transformers.models.albert import AlbertOnnxConfig
from transformers.models.auto import AutoTokenizer
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert import BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from .. import is_torch_available
from ..utils import logging
from .convert import export, validate_model_outputs
if is_torch_available():
from transformers import AutoModel, PreTrainedModel
FEATURES_TO_AUTOMODELS = {
"default": AutoModel,
}
# Set of model topologies we support associated to the features supported by each topology and the factory
SUPPORTED_MODEL_KIND = {
"albert": {"default": AlbertOnnxConfig.default},
"bart": {"default": BartOnnxConfig.default},
"bert": {"default": BertOnnxConfig.default},
"distilbert": {"default": DistilBertOnnxConfig.default},
"gpt2": {"default": GPT2OnnxConfig.default},
"longformer": {"default": LongformerOnnxConfig.default},
"roberta": {"default": RobertaOnnxConfig},
"t5": {"default": T5OnnxConfig.default},
"xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
}
def get_model_from_features(features: str, model: str):
"""
Attempt to retrieve a model from a model's name and the features to be enabled.
Args:
features: The features required
model: The name of the model to export
Returns:
"""
if features not in FEATURES_TO_AUTOMODELS:
raise KeyError(f"Unknown feature: {features}." f"Possible values are {list(FEATURES_TO_AUTOMODELS.values())}")
return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
def check_supported_model_or_raise(model: PreTrainedModel, features: str = "default") -> Tuple[str, Callable]:
"""
Check whether or not the model has the requested features
Args:
model: The model to export
features: The name of the features to check if they are avaiable
Returns:
(str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties
"""
if model.config.model_type not in SUPPORTED_MODEL_KIND:
raise KeyError(
f"{model.config.model_type} ({model.name}) is not supported yet. "
f"Only {SUPPORTED_MODEL_KIND} are supported. "
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
)
# Look for the features
model_features = SUPPORTED_MODEL_KIND[model.config.model_type]
if features not in model_features:
raise ValueError(
f"{model.config.model_type} doesn't support features {features}. "
f"Supported values are: {list(model_features.keys())}"
)
return model.config.model_type, SUPPORTED_MODEL_KIND[model.config.model_type][features]
from .features import FeaturesManager
def main():
parser = ArgumentParser("Hugging Face ONNX Exporter tool")
parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.")
parser.add_argument(
"--features",
choices=["default"],
"--feature",
choices=list(FeaturesManager.AVAILABLE_FEATURES),
default="default",
help="Export the model with some additional features.",
help="Export the model with some additional feature.",
)
parser.add_argument(
"--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
......@@ -127,8 +48,8 @@ def main():
# Allocate the model
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = get_model_from_features(args.features, args.model)
model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
model = FeaturesManager.get_model_from_feature(args.feature, args.model)
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
onnx_config = model_onnx_config(model.config)
# Ensure the requested opset is sufficient
......
......@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Mapping, Optional
from typing import Any, Callable, List, Mapping, Optional
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
......@@ -26,6 +27,27 @@ DEFAULT_ONNX_OPSET = 11
EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024
@dataclasses.dataclass
class PatchingSpec:
"""
Data class that holds patching specifications.
Args:
o: Module / object where the op to patch is located
name: Name of the op to monkey patch
custom_op: Custom op that patches the original op
orig_op: Original op that is being patched
op_wrapper: Wrapper (optional) that wraps both the original and custom ops.
It is useful for ops that are class or static methods for instance.
"""
o: Any
name: str
custom_op: Callable
orig_op: Optional[Callable] = None
op_wrapper: Optional[Callable] = None
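# Example (taken from the GPT-Neo config in this diff): a spec that swaps
# torch.Tensor.unfold for an ONNX-exportable implementation is written as
#     PatchingSpec(torch.Tensor, name="unfold", custom_op=custom_unfold)
# orig_op is filled in by OnnxConfig.__init__ below when left as None.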
class OnnxConfig(ABC):
"""
Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
......@@ -34,11 +56,38 @@ class OnnxConfig(ABC):
DEFAULT_FIXED_BATCH = 2
DEFAULT_FIXED_SEQUENCE = 8
def __init__(self, config: PretrainedConfig):
_TASKS_TO_COMMON_OUTPUTS = {
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
"question-answering": OrderedDict(
{
"start_logits": {0: "batch", 1: "sequence"},
"end_logits": {0: "batch", 1: "sequence"},
}
),
}
def __init__(self, config: PretrainedConfig, task: str = "default", patching_specs: List[PatchingSpec] = None):
self._config = config
if task not in self._TASKS_TO_COMMON_OUTPUTS:
raise ValueError(
f"{task} is not a supported task, supported tasks: {self._TASKS_TO_COMMON_OUTPUTS.keys()}"
)
self.task = task
self._patching_specs = []
for spec in patching_specs if patching_specs is not None else []:
final_spec = spec
if spec.orig_op is None:
final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
self._patching_specs.append(final_spec)
@classmethod
def default(cls, config: PretrainedConfig) -> "OnnxConfig":
def from_model_config(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfig":
"""
Instantiate a OnnxConfig for a specific model
......@@ -48,7 +97,7 @@ class OnnxConfig(ABC):
Returns:
OnnxConfig for this model
"""
return cls(config)
return cls(config, task=task)
@property
@abstractmethod
......@@ -62,7 +111,6 @@ class OnnxConfig(ABC):
raise NotImplementedError()
@property
@abstractmethod
def outputs(self) -> Mapping[str, Mapping[int, str]]:
"""
Mapping containing the axis definition of the output tensors to provide to the model
......@@ -70,7 +118,7 @@ class OnnxConfig(ABC):
Returns:
For each output: its name associated to the axes symbolic name and the axis position within the tensor
"""
raise NotImplementedError()
return self._TASKS_TO_COMMON_OUTPUTS[self.task]
@property
def values_override(self) -> Optional[Mapping[str, Any]]:
......@@ -170,14 +218,30 @@ class OnnxConfig(ABC):
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
return dict(tokenizer(dummy_input, return_tensors=framework))
def patch_ops(self):
for spec in self._patching_specs:
custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
setattr(spec.o, spec.name, custom_op)
def restore_ops(self):
for spec in self._patching_specs:
orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
setattr(spec.o, spec.name, orig_op)
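# Intended call pattern (see the convert.py hunk below): patch_ops() is invoked
# right before the ONNX export so the patched ops are in effect during tracing,
# and restore_ops() right after, leaving the Python-side model untouched.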
class OnnxConfigWithPast(OnnxConfig, ABC):
def __init__(self, config: PretrainedConfig, use_past: bool = False):
super().__init__(config)
def __init__(
self,
config: PretrainedConfig,
task: str = "default",
patching_specs: List[PatchingSpec] = None,
use_past: bool = False,
):
super().__init__(config, task=task, patching_specs=patching_specs)
self.use_past = use_past
@classmethod
def with_past(cls, config: PretrainedConfig) -> "OnnxConfigWithPast":
def with_past(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfigWithPast":
"""
Instantiate a OnnxConfig with `use_past` attribute set to True
......@@ -187,7 +251,7 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
Returns:
OnnxConfig with `.use_past = True`
"""
return cls(config, use_past=True)
return cls(config, task=task, use_past=True)
@property
def values_override(self) -> Optional[Mapping[str, Any]]:
......
......@@ -111,6 +111,8 @@ def export(
if not inputs_match:
raise ValueError("Model and config inputs doesn't match")
config.patch_ops()
# export can work with named args, but the dict containing named args has to be the last element of the args tuple
export(
model,
......@@ -125,6 +127,8 @@ def export(
opset_version=opset,
)
config.restore_ops()
return matched_inputs, onnx_outputs
......@@ -140,6 +144,8 @@ def validate_model_outputs(
logger.info("Validating ONNX model...")
# TODO: generate inputs with a different batch_size and seq_len than was used for conversion to properly test
# dynamic input shapes.
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
# Create ONNX Runtime session
......@@ -152,6 +158,10 @@ def validate_model_outputs(
# We flatten potential collection of outputs (i.e. past_keys) to a flat structure
for name, value in ref_outputs.items():
# Overwriting the output name as "present" since it is the name used for the ONNX outputs
# ("past_key_values" being taken for the ONNX inputs)
if name == "past_key_values":
name = "present"
if isinstance(value, (list, tuple)):
value = flatten_output_collection_property(name, value)
ref_outputs_dict.update(value)
......@@ -186,7 +196,7 @@ def validate_model_outputs(
# Check the shape and values match
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
ref_value = ref_outputs_dict[name].numpy()
ref_value = ref_outputs_dict[name].detach().numpy()
logger.info(f'\t- Validating ONNX Model output "{name}":')
# Shape
......@@ -197,7 +207,7 @@ def validate_model_outputs(
f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
)
else:
logger.info(f"\t\t-[✓] {ort_value.shape} matchs {ref_value.shape}")
logger.info(f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}")
# Values
if not np.allclose(ref_value, ort_value, atol=atol):
......
from functools import partial, reduce
from typing import Callable, Tuple
from .. import is_torch_available
from ..models.albert import AlbertOnnxConfig
from ..models.bart import BartOnnxConfig
from ..models.bert import BertOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.longformer import LongformerOnnxConfig
from ..models.roberta import RobertaOnnxConfig
from ..models.t5 import T5OnnxConfig
from ..models.xlm_roberta import XLMRobertaOnnxConfig
if is_torch_available():
from transformers import PreTrainedModel
from transformers.models.auto import (
AutoModel,
AutoModelForCausalLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)
def supported_features_mapping(*supported_features, onnx_config_cls=None):
"""Generates the mapping between supported features and their corresponding OnnxConfig."""
if onnx_config_cls is None:
raise ValueError("A OnnxConfig class must be provided")
mapping = {}
for feature in supported_features:
if "-with-past" in feature:
task = feature.replace("-with-past", "")
mapping[feature] = partial(onnx_config_cls.with_past, task=task)
else:
mapping[feature] = partial(onnx_config_cls.from_model_config, task=feature)
return mapping
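# Illustration: supported_features_mapping("default", "causal-lm-with-past", onnx_config_cls=GPTNeoOnnxConfig)
# returns
#     {
#         "default": partial(GPTNeoOnnxConfig.from_model_config, task="default"),
#         "causal-lm-with-past": partial(GPTNeoOnnxConfig.with_past, task="causal-lm"),
#     }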
class FeaturesManager:
_TASKS_TO_AUTOMODELS = {
"default": AutoModel,
"causal-lm": AutoModelForCausalLM,
"sequence-classification": AutoModelForSequenceClassification,
"token-classification": AutoModelForTokenClassification,
"multiple-choice": AutoModelForMultipleChoice,
"question-answering": AutoModelForQuestionAnswering,
}
# Set of model topologies we support, associated with the features supported by each topology and the corresponding OnnxConfig factory
_SUPPORTED_MODEL_KIND = {
"albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
"roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
"t5": supported_features_mapping("default", onnx_config_cls=T5OnnxConfig),
"xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
"gpt-neo": supported_features_mapping(
"default",
"causal-lm",
"sequence-classification",
"default-with-past",
"causal-lm-with-past",
"sequence-classification-with-past",
onnx_config_cls=GPTNeoOnnxConfig,
),
}
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))
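# With the table above, AVAILABLE_FEATURES evaluates to the sorted union of every
# feature key, i.e. ["causal-lm", "causal-lm-with-past", "default",
# "default-with-past", "sequence-classification", "sequence-classification-with-past"].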
@staticmethod
def feature_to_task(feature: str) -> str:
return feature.replace("-with-past", "")
@staticmethod
def get_model_from_feature(feature: str, model: str):
"""
Attempt to retrieve a model from a model's name and the feature to be enabled.
Args:
feature: The feature required
model: The name of the model to export
Returns:
The model instance, loaded with the AutoModel class matching the requested feature.
"""
task = FeaturesManager.feature_to_task(feature)
if task not in FeaturesManager._TASKS_TO_AUTOMODELS:
raise KeyError(
f"Unknown task: {feature}."
f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
)
return FeaturesManager._TASKS_TO_AUTOMODELS[task].from_pretrained(model)
@staticmethod
def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
"""
Check whether or not the model has the requested features
Args:
model: The model to export
feature: The name of the feature to check if it is available
Returns:
(str) The type of the model, (OnnxConfig) The OnnxConfig instance holding the model export properties
"""
model_type = model.config.model_type.replace("_", "-")
model_name = getattr(model, "name", "")
model_name = f"({model_name})" if model_name else ""
if model_type not in FeaturesManager._SUPPORTED_MODEL_KIND:
raise KeyError(
f"{model.config.model_type} ({model_name}) is not supported yet. "
f"Only {FeaturesManager._SUPPORTED_MODEL_KIND} are supported. "
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
)
# Look for the features
model_features = FeaturesManager._SUPPORTED_MODEL_KIND[model_type]
if feature not in model_features:
raise ValueError(
f"{model.config.model_type} doesn't support feature {feature}. "
f"Supported values are: {list(model_features.keys())}"
)
return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_KIND[model_type][feature]
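# Illustration: for a BertModel instance, check_supported_model_or_raise(model) returns
# ("bert", <factory for BertOnnxConfig>), whereas feature="causal-lm" raises a
# ValueError because "bert" only registers the "default" feature above.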
......@@ -9,6 +9,7 @@ from transformers import ( # LongformerConfig,; T5Config,
BartConfig,
DistilBertConfig,
GPT2Config,
GPTNeoConfig,
RobertaConfig,
XLMRobertaConfig,
is_torch_available,
......@@ -20,6 +21,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
# from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
# from transformers.models.t5 import T5OnnxConfig
......@@ -151,7 +153,8 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
with self.subTest(name):
self.assertFalse(
OnnxConfigWithPast.default(config()).use_past, "OnnxConfigWithPast.default() should not use_past"
OnnxConfigWithPast.from_model_config(config()).use_past,
"OnnxConfigWithPast.from_model_config() should not use_past",
)
self.assertTrue(
......@@ -167,7 +170,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
with self.subTest(name):
# without past
onnx_config_default = OnnxConfigWithPast.default(config())
onnx_config_default = OnnxConfigWithPast.from_model_config(config())
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
self.assertFalse(
......@@ -190,6 +193,7 @@ if is_torch_available():
BertModel,
DistilBertModel,
GPT2Model,
GPTNeoModel,
RobertaModel,
XLMRobertaModel,
)
......@@ -200,6 +204,7 @@ if is_torch_available():
("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig),
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
......