"test/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "9877f54491a7081266207e1a999dd47bc2bba17e"
Unverified commit a97f3150, authored by lewtun and committed by GitHub

Add ONNX support for Blenderbot and BlenderbotSmall (#15875)

* Add ONNX support for Blenderbot

* Add BlenderbotSmall ONNX configuration

* Update serialization table
parent b473617d
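
For context, the configurations added below plug into the existing `transformers.onnx` export flow. A minimal sketch of an end-to-end export (the output path is illustrative, and the positional `(tokenizer, model, onnx_config, opset, output)` call into the `export` helper is an assumption about the helper's calling convention at this point in the library):

```python
from pathlib import Path

from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizer
from transformers.models.blenderbot import BlenderbotOnnxConfig
from transformers.onnx import export

# Checkpoint also used in the test list at the bottom of this commit.
checkpoint = "facebook/blenderbot-400M-distill"
tokenizer = BlenderbotTokenizer.from_pretrained(checkpoint)
model = BlenderbotForConditionalGeneration.from_pretrained(checkpoint)

# Build the ONNX config added in this commit for the seq2seq-lm feature.
onnx_config = BlenderbotOnnxConfig(model.config, task="seq2seq-lm")

# Trace the model, write the ONNX graph, and collect the exported input/output names.
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("blenderbot.onnx")
)
print(onnx_inputs, onnx_outputs)
```

The command-line entry point (`python -m transformers.onnx`, with its `--model` and `--feature` arguments) should drive the same flow.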
@@ -48,6 +48,8 @@ Ready-made configurations include the following architectures:
 - ALBERT
 - BART
 - BERT
+- Blenderbot
+- BlenderbotSmall
 - CamemBERT
 - Data2VecText
 - DistilBERT
...
@@ -22,7 +22,11 @@ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available
 _import_structure = {
-    "configuration_blenderbot": ["BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotConfig"],
+    "configuration_blenderbot": [
+        "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "BlenderbotConfig",
+        "BlenderbotOnnxConfig",
+    ],
     "tokenization_blenderbot": ["BlenderbotTokenizer"],
 }
@@ -56,7 +60,11 @@ if is_flax_available():
 if TYPE_CHECKING:
-    from .configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
+    from .configuration_blenderbot import (
+        BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        BlenderbotConfig,
+        BlenderbotOnnxConfig,
+    )
     from .tokenization_blenderbot import BlenderbotTokenizer
     if is_tokenizers_available():
...
@@ -14,7 +14,14 @@
 # limitations under the License.
 """ Blenderbot model configuration"""
+from collections import OrderedDict
+from typing import Any, Mapping, Optional
+
+from ... import PreTrainedTokenizer
 from ...configuration_utils import PretrainedConfig
+from ...file_utils import TensorType, is_torch_available
+from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
+from ...onnx.utils import compute_effective_axis_dimension
 from ...utils import logging
@@ -164,3 +171,229 @@ class BlenderbotConfig(PretrainedConfig):
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )


class BlenderbotOnnxConfig(OnnxSeq2SeqConfigWithPast):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )
            if self.use_past:
                common_inputs["decoder_input_ids"] = {0: "batch"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
            else:
                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
            if self.use_past:
                self.fill_with_past_key_values_(common_inputs, direction="inputs")
        elif self.task == "causal-lm":
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )
            if self.use_past:
                _, num_decoder_layers = self.num_layers
                for i in range(num_decoder_layers):
                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        else:
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
                ]
            )

        return common_inputs

    @property
    # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.outputs
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task in ["default", "seq2seq-lm"]:
            common_outputs = super().outputs
        else:
            common_outputs = super(OnnxConfigWithPast, self).outputs
            if self.use_past:
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        return common_outputs

    def _generate_dummy_inputs_for_default_and_seq2seq_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        # Generate decoder inputs
        decoder_seq_length = seq_length if not self.use_past else 1
        decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, decoder_seq_length, is_pair, framework
        )
        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
        common_inputs = dict(**encoder_inputs, **decoder_inputs)

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            batch, encoder_seq_length = common_inputs["input_ids"].shape
            decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
            num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
            encoder_shape = (
                batch,
                num_encoder_attention_heads,
                encoder_seq_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )
            decoder_past_length = decoder_seq_length
            decoder_shape = (
                batch,
                num_decoder_attention_heads,
                decoder_past_length,
                self._config.hidden_size // num_decoder_attention_heads,
            )

            common_inputs["decoder_attention_mask"] = torch.cat(
                [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
            )

            common_inputs["past_key_values"] = []
            _, num_decoder_layers = self.num_layers
            for _ in range(num_decoder_layers):
                common_inputs["past_key_values"].append(
                    (
                        torch.zeros(decoder_shape),
                        torch.zeros(decoder_shape),
                        torch.zeros(encoder_shape),
                        torch.zeros(encoder_shape),
                    )
                )
        return common_inputs

    def _generate_dummy_inputs_for_causal_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            batch, seqlen = common_inputs["input_ids"].shape
            past_key_values_length = seqlen
            _, num_decoder_layers = self.num_layers
            num_encoder_attention_heads, _ = self.num_attention_heads
            past_shape = (
                batch,
                num_encoder_attention_heads,
                past_key_values_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )

            common_inputs["attention_mask"] = torch.cat(
                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
            )
            common_inputs["past_key_values"] = [
                (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_decoder_layers)
            ]
        return common_inputs

    # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
    def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # Copied from OnnxConfig.generate_dummy_inputs
        # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
        # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
        batch_size = compute_effective_axis_dimension(
            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
        )

        # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
        seq_length = compute_effective_axis_dimension(
            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
        )

        # Generate dummy inputs according to compute batch and sequence
        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
        common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
        return common_inputs

    # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.generate_dummy_inputs
    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )
        elif self.task == "causal-lm":
            common_inputs = self._generate_dummy_inputs_for_causal_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )
        else:
            common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )

        return common_inputs

    # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._flatten_past_key_values_
    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
        if self.task in ["default", "seq2seq-lm"]:
            flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
        else:
            flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
                flattened_output, name, idx, t
            )

    def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str):
        if direction not in ["inputs", "outputs"]:
            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

        name = "past_key_values" if direction == "inputs" else "present"
        # Blenderbot names the past key/value tensors with the decoder layer count for both the
        # self-attention (decoder) and cross-attention (encoder) entries.
        _, num_decoder_layers = self.num_layers
        encoder_sequence = "past_encoder_sequence"
        decoder_sequence = "past_decoder_sequence" if direction == "inputs" else "past_decoder_sequence + sequence"
        for i in range(num_decoder_layers):
            inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch", 2: decoder_sequence}
            inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch", 2: decoder_sequence}
            inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch", 2: encoder_sequence}
            inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch", 2: encoder_sequence}
@@ -21,7 +21,11 @@ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available
 _import_structure = {
-    "configuration_blenderbot_small": ["BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotSmallConfig"],
+    "configuration_blenderbot_small": [
+        "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "BlenderbotSmallConfig",
+        "BlenderbotSmallOnnxConfig",
+    ],
     "tokenization_blenderbot_small": ["BlenderbotSmallTokenizer"],
 }
@@ -52,7 +56,11 @@ if is_flax_available():
     ]
 if TYPE_CHECKING:
-    from .configuration_blenderbot_small import BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotSmallConfig
+    from .configuration_blenderbot_small import (
+        BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        BlenderbotSmallConfig,
+        BlenderbotSmallOnnxConfig,
+    )
     from .tokenization_blenderbot_small import BlenderbotSmallTokenizer
     if is_tokenizers_available():
...
@@ -14,7 +14,14 @@
 # limitations under the License.
 """ BlenderbotSmall model configuration"""
+from collections import OrderedDict
+from typing import Any, Mapping, Optional
+
+from ... import PreTrainedTokenizer
 from ...configuration_utils import PretrainedConfig
+from ...file_utils import TensorType, is_torch_available
+from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
+from ...onnx.utils import compute_effective_axis_dimension
 from ...utils import logging
@@ -162,3 +169,226 @@ class BlenderbotSmallConfig(PretrainedConfig):
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )


# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig
class BlenderbotSmallOnnxConfig(OnnxSeq2SeqConfigWithPast):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )
            if self.use_past:
                common_inputs["decoder_input_ids"] = {0: "batch"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
            else:
                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
            if self.use_past:
                self.fill_with_past_key_values_(common_inputs, direction="inputs")
        elif self.task == "causal-lm":
            # TODO: figure this case out.
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )
            if self.use_past:
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        else:
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
                ]
            )

        return common_inputs

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task in ["default", "seq2seq-lm"]:
            common_outputs = super().outputs
        else:
            common_outputs = super(OnnxConfigWithPast, self).outputs
            if self.use_past:
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        return common_outputs

    def _generate_dummy_inputs_for_default_and_seq2seq_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        # Generate decoder inputs
        decoder_seq_length = seq_length if not self.use_past else 1
        decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, decoder_seq_length, is_pair, framework
        )
        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
        common_inputs = dict(**encoder_inputs, **decoder_inputs)

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            batch, encoder_seq_length = common_inputs["input_ids"].shape
            decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
            num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
            encoder_shape = (
                batch,
                num_encoder_attention_heads,
                encoder_seq_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )
            decoder_past_length = decoder_seq_length + 3
            decoder_shape = (
                batch,
                num_decoder_attention_heads,
                decoder_past_length,
                self._config.hidden_size // num_decoder_attention_heads,
            )

            common_inputs["decoder_attention_mask"] = torch.cat(
                [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
            )

            common_inputs["past_key_values"] = []
            # If the number of encoder and decoder layers are present in the model configuration, both are considered
            num_encoder_layers, num_decoder_layers = self.num_layers
            min_num_layers = min(num_encoder_layers, num_decoder_layers)
            max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
            remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

            for _ in range(min_num_layers):
                common_inputs["past_key_values"].append(
                    (
                        torch.zeros(decoder_shape),
                        torch.zeros(decoder_shape),
                        torch.zeros(encoder_shape),
                        torch.zeros(encoder_shape),
                    )
                )
            # TODO: test this.
            shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
            for _ in range(min_num_layers, max_num_layers + min_num_layers):
                common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
        return common_inputs

    def _generate_dummy_inputs_for_causal_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            batch, seqlen = common_inputs["input_ids"].shape
            # Not using the same length for past_key_values
            past_key_values_length = seqlen + 2
            num_encoder_layers, _ = self.num_layers
            num_encoder_attention_heads, _ = self.num_attention_heads
            past_shape = (
                batch,
                num_encoder_attention_heads,
                past_key_values_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )

            common_inputs["attention_mask"] = torch.cat(
                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
            )
            common_inputs["past_key_values"] = [
                (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
            ]
        return common_inputs

    def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # Copied from OnnxConfig.generate_dummy_inputs
        # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
        # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
        batch_size = compute_effective_axis_dimension(
            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
        )

        # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
        seq_length = compute_effective_axis_dimension(
            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
        )

        # Generate dummy inputs according to compute batch and sequence
        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
        common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
        return common_inputs

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )
        elif self.task == "causal-lm":
            common_inputs = self._generate_dummy_inputs_for_causal_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )
        else:
            common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )

        return common_inputs

    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
        if self.task in ["default", "seq2seq-lm"]:
            flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
        else:
            flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
                flattened_output, name, idx, t
            )
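
The `*-with-past` features map onto the same configuration classes with `use_past=True`, which is normally obtained through the `with_past` constructor. A hedged sketch of inspecting the extra `past_key_values` axes for BlenderbotSmall (assuming `OnnxConfigWithPast.with_past` is the entry point used for the with-past features):

```python
from transformers import BlenderbotSmallConfig
from transformers.models.blenderbot_small import BlenderbotSmallOnnxConfig

config = BlenderbotSmallConfig.from_pretrained("facebook/blenderbot_small-90M")

# use_past=True corresponds to the "seq2seq-lm-with-past" feature.
onnx_config = BlenderbotSmallOnnxConfig.with_past(config, task="seq2seq-lm")

for name, axes in onnx_config.inputs.items():
    if name.startswith("past_key_values"):
        # e.g. past_key_values.0.decoder.key -> {0: "batch", 2: "past_decoder_sequence"}
        print(name, axes)
```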
@@ -5,6 +5,8 @@ from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available
 from ..models.albert import AlbertOnnxConfig
 from ..models.bart import BartOnnxConfig
 from ..models.bert import BertOnnxConfig
+from ..models.blenderbot import BlenderbotOnnxConfig
+from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
 from ..models.camembert import CamembertOnnxConfig
 from ..models.distilbert import DistilBertOnnxConfig
 from ..models.electra import ElectraOnnxConfig
@@ -268,6 +270,24 @@ class FeaturesManager:
             onnx_config_cls=ElectraOnnxConfig,
         ),
         "vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
+        "blenderbot": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "causal-lm",
+            "causal-lm-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            onnx_config_cls=BlenderbotOnnxConfig,
+        ),
+        "blenderbot-small": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "causal-lm",
+            "causal-lm-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            onnx_config_cls=BlenderbotSmallOnnxConfig,
+        ),
     }
     AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
...
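
With the two entries above registered, the feature table can be queried programmatically. A small sketch using `FeaturesManager.get_supported_features_for_model_type`, one of the lookup helpers the ONNX export tooling goes through:

```python
from transformers.onnx.features import FeaturesManager

# Expected: default, causal-lm and seq2seq-lm, each with a -with-past variant.
print(sorted(FeaturesManager.get_supported_features_for_model_type("blenderbot")))
print(sorted(FeaturesManager.get_supported_features_for_model_type("blenderbot-small")))
```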
@@ -194,6 +194,8 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
     ("t5", "t5-small"),
     ("marian", "Helsinki-NLP/opus-mt-en-de"),
     ("m2m-100", "facebook/m2m100_418M"),
+    ("blenderbot-small", "facebook/blenderbot_small-90M"),
+    ("blenderbot", "facebook/blenderbot-400M-distill"),
 }
 # TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
...