"docs/source/en/preprocessing.md" did not exist on "b9a768b3ffa80c4c19d024f9f42d5917e7d8109e"
Unverified Commit 8cb5ecd9, authored by Thomas Chaigneau, committed by GitHub

Add mt5 onnx config (#18394)

* update features

* MT5OnnxConfig added, with updated tests and docs

* fix imports

* fix onnx_config_cls for mt5

Co-authored-by: Thomas Chaigneau <thomas.deeptools.ai>
parent fe785730
@@ -79,6 +79,7 @@ Ready-made configurations include the following architectures:
 - mBART
 - MobileBERT
 - MobileViT
+- MT5
 - OpenAI GPT-2
 - Perceiver
 - PLBart
@@ -43,7 +43,7 @@ else:
     MT5TokenizerFast = T5TokenizerFast
-_import_structure = {"configuration_mt5": ["MT5Config"]}
+_import_structure = {"configuration_mt5": ["MT5Config", "MT5OnnxConfig"]}
 try:
     if not is_torch_available():
@@ -71,7 +71,7 @@ else:
 if TYPE_CHECKING:
-    from .configuration_mt5 import MT5Config
+    from .configuration_mt5 import MT5Config, MT5OnnxConfig
     try:
         if not is_torch_available():
@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ mT5 model configuration"""
+from typing import Mapping
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxSeq2SeqConfigWithPast
 from ...utils import logging
@@ -143,3 +145,29 @@ class MT5Config(PretrainedConfig):
     @property
     def num_hidden_layers(self):
         return self.num_layers
+
+
+# Copied from transformers.models.t5.configuration_t5.T5OnnxConfig
+class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = {
+            "input_ids": {0: "batch", 1: "encoder_sequence"},
+            "attention_mask": {0: "batch", 1: "encoder_sequence"},
+        }
+        if self.use_past:
+            common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
+            common_inputs["decoder_input_ids"] = {0: "batch"}
+            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+        else:
+            common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+        if self.use_past:
+            self.fill_with_past_key_values_(common_inputs, direction="inputs")
+
+        return common_inputs
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 13
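The hunk above defines MT5OnnxConfig by reusing OnnxSeq2SeqConfigWithPast, so past-key-value inputs are filled in automatically when use_past is enabled. A minimal sketch of using it from Python (the "google/mt5-small" checkpoint and the "seq2seq-lm" task are illustrative choices, not fixed by this commit):

# Sketch, not part of the diff: inspect the dynamic axes declared by the new config.
from transformers import MT5Config
from transformers.models.mt5 import MT5OnnxConfig

config = MT5Config.from_pretrained("google/mt5-small")  # assumed example checkpoint
onnx_config = MT5OnnxConfig(config, task="seq2seq-lm")

print(onnx_config.inputs)              # encoder/decoder inputs with dynamic batch/sequence axes
print(onnx_config.outputs)             # task-dependent outputs from OnnxSeq2SeqConfigWithPast
print(onnx_config.default_onnx_opset)  # 13, as defined above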
@@ -383,6 +383,13 @@ class FeaturesManager:
             "image-classification",
             onnx_config_cls="models.mobilevit.MobileViTOnnxConfig",
         ),
+        "mt5": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            onnx_config_cls="models.mt5.MT5OnnxConfig",
+        ),
         "m2m-100": supported_features_mapping(
             "default",
             "default-with-past",
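The registry entry above is what the ONNX exporter consults when asked to handle an mt5 model. A quick sketch of checking the registered features (assuming FeaturesManager.get_supported_features_for_model_type, the lookup helper this mapping feeds into):

# Sketch, not part of the diff: confirm the features registered for mt5.
from transformers.onnx.features import FeaturesManager

mt5_features = FeaturesManager.get_supported_features_for_model_type("mt5")
print(sorted(mt5_features))
# Expected, per the mapping above:
# ['default', 'default-with-past', 'seq2seq-lm', 'seq2seq-lm-with-past']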
@@ -224,6 +224,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
     ("mbart", "sshleifer/tiny-mbart"),
     ("t5", "t5-small"),
     ("marian", "Helsinki-NLP/opus-mt-en-de"),
+    ("mt5", "google/mt5-base"),
     ("m2m-100", "facebook/m2m100_418M"),
     ("blenderbot-small", "facebook/blenderbot_small-90M"),
     ("blenderbot", "facebook/blenderbot-400M-distill"),
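The new test entry covers the seq2seq export with past key values for google/mt5-base. A rough sketch of the export path such a test exercises, using the public transformers.onnx API rather than the test harness itself (google/mt5-base is a large download; a smaller mT5 checkpoint works the same way):

# Sketch, not the test itself: export an mT5 checkpoint with the new config
# and validate the ONNX outputs against the PyTorch model.
from pathlib import Path

from transformers import AutoTokenizer, MT5ForConditionalGeneration
from transformers.models.mt5 import MT5OnnxConfig
from transformers.onnx import export, validate_model_outputs

model_name = "google/mt5-base"  # any mT5 checkpoint should behave the same way
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name)

onnx_config = MT5OnnxConfig(model.config, task="seq2seq-lm")
onnx_path = Path("mt5.onnx")

# export() returns the matched ONNX input and output names.
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path
)
validate_model_outputs(
    onnx_config, tokenizer, model, onnx_path, onnx_outputs, onnx_config.atol_for_validation
)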