Unverified commit 675e2d16 authored by lewtun, committed by GitHub
Browse files

Remove masked image modeling from BEIT ONNX export (#16980)



* Add masked image modeling to task mapping

* Refactor ONNX features to be listed alphabetically

* Add warning about BEiT masked image modeling
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 4bb1d0ec
......@@ -733,7 +733,10 @@ class BeitPooler(nn.Module):
@add_start_docstrings(
"Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).",
"""Beit Model transformer with a 'language' modeling head on top. BEiT does masked image modeling by predicting
visual tokens of a Vector-Quantize Variational Autoencoder (VQ-VAE), whereas other vision models like ViT and DeiT
predict RGB pixel values. As a result, this class is incompatible with [`AutoModelForMaskedImageModeling`], so you
will need to use [`BeitForMaskedImageModeling`] directly if you wish to do masked image modeling with BEiT.""",
BEIT_START_DOCSTRING,
)
class BeitForMaskedImageModeling(BeitPreTrainedModel):
......
......@@ -74,12 +74,11 @@ class OnnxConfig(ABC):
default_fixed_num_choices = 4
torch_onnx_minimum_version = version.parse("1.8")
_tasks_to_common_outputs = {
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
"image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
"question-answering": OrderedDict(
{
......@@ -87,7 +86,9 @@ class OnnxConfig(ABC):
"end_logits": {0: "batch", 1: "sequence"},
}
),
"image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
}
def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):
......
......@@ -142,17 +142,8 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=BartOnnxConfig,
),
"mbart": supported_features_mapping(
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
"sequence-classification",
"question-answering",
onnx_config_cls=MBartOnnxConfig,
),
# BEiT cannot be used with the masked image modeling autoclass, so this feature is excluded here
"beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
"bert": supported_features_mapping(
"default",
"masked-lm",
......@@ -173,14 +164,23 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=BigBirdOnnxConfig,
),
"ibert": supported_features_mapping(
"blenderbot": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=IBertOnnxConfig,
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
onnx_config_cls=BlenderbotOnnxConfig,
),
"blenderbot-small": supported_features_mapping(
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
onnx_config_cls=BlenderbotSmallOnnxConfig,
),
"camembert": supported_features_mapping(
"default",
......@@ -201,38 +201,28 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=ConvBertOnnxConfig,
),
"distilbert": supported_features_mapping(
"data2vec-text": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=DistilBertOnnxConfig,
onnx_config_cls=Data2VecTextOnnxConfig,
),
"flaubert": supported_features_mapping(
"deit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=DeiTOnnxConfig
),
"distilbert": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=FlaubertOnnxConfig,
),
"marian": supported_features_mapping(
"default",
"default-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
"causal-lm",
"causal-lm-with-past",
onnx_config_cls=MarianOnnxConfig,
),
"m2m-100": supported_features_mapping(
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
onnx_config_cls=DistilBertOnnxConfig,
),
"roberta": supported_features_mapping(
"electra": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
......@@ -240,12 +230,9 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=RobertaOnnxConfig,
),
"t5": supported_features_mapping(
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
onnx_config_cls=ElectraOnnxConfig,
),
"xlm-roberta": supported_features_mapping(
"flaubert": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
......@@ -253,7 +240,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=XLMRobertaOnnxConfig,
onnx_config_cls=FlaubertOnnxConfig,
),
"gpt2": supported_features_mapping(
"default",
......@@ -281,58 +268,54 @@ class FeaturesManager:
"sequence-classification",
onnx_config_cls=GPTNeoOnnxConfig,
),
"layoutlm": supported_features_mapping(
"ibert": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
onnx_config_cls=LayoutLMOnnxConfig,
"question-answering",
onnx_config_cls=IBertOnnxConfig,
),
"electra": supported_features_mapping(
"layoutlm": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=ElectraOnnxConfig,
),
"vit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
),
"beit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=BeitOnnxConfig
),
"deit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=DeiTOnnxConfig
onnx_config_cls=LayoutLMOnnxConfig,
),
"blenderbot": supported_features_mapping(
"marian": supported_features_mapping(
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
onnx_config_cls=BlenderbotOnnxConfig,
"causal-lm",
"causal-lm-with-past",
onnx_config_cls=MarianOnnxConfig,
),
"blenderbot-small": supported_features_mapping(
"mbart": supported_features_mapping(
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
onnx_config_cls=BlenderbotSmallOnnxConfig,
"sequence-classification",
"question-answering",
onnx_config_cls=MBartOnnxConfig,
),
"data2vec-text": supported_features_mapping(
"m2m-100": supported_features_mapping(
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
),
"roberta": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=Data2VecTextOnnxConfig,
onnx_config_cls=RobertaOnnxConfig,
),
"roformer": supported_features_mapping(
"default",
......@@ -345,6 +328,22 @@ class FeaturesManager:
"token-classification",
onnx_config_cls=RoFormerOnnxConfig,
),
"t5": supported_features_mapping(
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
),
"vit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
),
"xlm-roberta": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=XLMRobertaOnnxConfig,
),
}
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment