Unverified commit 675e2d16, authored by lewtun and committed by GitHub

Remove masked image modeling from BEIT ONNX export (#16980)



* Add masked image modelling to task mapping

* Refactor ONNX features to be listed alphabetically

* Add warning about BEiT masked image modeling
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 4bb1d0ec
@@ -733,7 +733,10 @@ class BeitPooler(nn.Module):
 @add_start_docstrings(
-    "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).",
+    """Beit Model transformer with a 'language' modeling head on top. BEiT does masked image modeling by predicting
+    visual tokens of a Vector-Quantize Variational Autoencoder (VQ-VAE), whereas other vision models like ViT and DeiT
+    predict RGB pixel values. As a result, this class is incompatible with [`AutoModelForMaskedImageModeling`], so you
+    will need to use [`BeitForMaskedImageModeling`] directly if you wish to do masked image modeling with BEiT.""",
     BEIT_START_DOCSTRING,
 )
 class BeitForMaskedImageModeling(BeitPreTrainedModel):
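As the new docstring explains, BEiT does masked image modeling by predicting VQ-VAE visual tokens rather than RGB pixel values, so it cannot be loaded through `AutoModelForMaskedImageModeling`. A minimal usage sketch, assuming a BEiT checkpoint such as `microsoft/beit-base-patch16-224-pt22k` (the checkpoint name is an illustrative assumption, not part of this commit):

from transformers import BeitForMaskedImageModeling

# Load the BEiT-specific class directly instead of the masked image modeling autoclass.
# The checkpoint name is only an example of a BEiT checkpoint pretrained with visual tokens.
model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")

# AutoModelForMaskedImageModeling.from_pretrained(...) is not expected to work for BEiT,
# because its head predicts visual tokens instead of pixel values.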
@@ -74,12 +74,11 @@ class OnnxConfig(ABC):
     default_fixed_num_choices = 4
     torch_onnx_minimum_version = version.parse("1.8")
     _tasks_to_common_outputs = {
+        "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
         "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
+        "image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+        "masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
         "masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
-        "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
-        "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
-        "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
-        "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
         "multiple-choice": OrderedDict({"logits": {0: "batch"}}),
         "question-answering": OrderedDict(
             {
@@ -87,7 +86,9 @@ class OnnxConfig(ABC):
                 "end_logits": {0: "batch", 1: "sequence"},
             }
         ),
-        "image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+        "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
+        "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
+        "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
     }

     def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):
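Each entry in `_tasks_to_common_outputs` maps a feature name to the ONNX output names and their dynamic axes; the reordering above is purely alphabetical, with the new `masked-im` task added. A small standalone sketch (illustrative only, not the library's internal code) of what such a lookup yields:

from collections import OrderedDict

# Two entries mirroring the mapping above: the task name tells the exporter
# which output dimensions are dynamic (batch- and sequence-sized).
tasks_to_common_outputs = {
    "image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
    "masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
}

print(tasks_to_common_outputs["masked-im"])
# OrderedDict([('logits', {0: 'batch', 1: 'sequence'})])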
@@ -142,17 +142,8 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls=BartOnnxConfig,
         ),
-        "mbart": supported_features_mapping(
-            "default",
-            "default-with-past",
-            "causal-lm",
-            "causal-lm-with-past",
-            "seq2seq-lm",
-            "seq2seq-lm-with-past",
-            "sequence-classification",
-            "question-answering",
-            onnx_config_cls=MBartOnnxConfig,
-        ),
+        # BEiT cannot be used with the masked image modeling autoclass, so this feature is excluded here
+        "beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
         "bert": supported_features_mapping(
             "default",
             "masked-lm",
@@ -173,14 +164,23 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls=BigBirdOnnxConfig,
         ),
-        "ibert": supported_features_mapping(
+        "blenderbot": supported_features_mapping(
             "default",
-            "masked-lm",
-            "sequence-classification",
-            "multiple-choice",
-            "token-classification",
-            "question-answering",
-            onnx_config_cls=IBertOnnxConfig,
+            "default-with-past",
+            "causal-lm",
+            "causal-lm-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            onnx_config_cls=BlenderbotOnnxConfig,
+        ),
+        "blenderbot-small": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "causal-lm",
+            "causal-lm-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            onnx_config_cls=BlenderbotSmallOnnxConfig,
         ),
         "camembert": supported_features_mapping(
             "default",
@@ -201,38 +201,28 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls=ConvBertOnnxConfig,
         ),
-        "distilbert": supported_features_mapping(
-            "default",
-            "masked-lm",
-            "sequence-classification",
-            "multiple-choice",
-            "token-classification",
-            "question-answering",
-            onnx_config_cls=DistilBertOnnxConfig,
-        ),
-        "flaubert": supported_features_mapping(
-            "default",
-            "masked-lm",
-            "causal-lm",
-            "sequence-classification",
-            "multiple-choice",
-            "token-classification",
-            "question-answering",
-            onnx_config_cls=FlaubertOnnxConfig,
-        ),
-        "marian": supported_features_mapping(
-            "default",
-            "default-with-past",
-            "seq2seq-lm",
-            "seq2seq-lm-with-past",
-            "causal-lm",
-            "causal-lm-with-past",
-            onnx_config_cls=MarianOnnxConfig,
-        ),
-        "m2m-100": supported_features_mapping(
-            "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
-        ),
-        "roberta": supported_features_mapping(
+        "data2vec-text": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "sequence-classification",
+            "multiple-choice",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls=Data2VecTextOnnxConfig,
+        ),
+        "deit": supported_features_mapping(
+            "default", "image-classification", "masked-im", onnx_config_cls=DeiTOnnxConfig
+        ),
+        "distilbert": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "sequence-classification",
+            "multiple-choice",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls=DistilBertOnnxConfig,
+        ),
+        "electra": supported_features_mapping(
             "default",
             "masked-lm",
             "causal-lm",
@@ -240,12 +230,9 @@ class FeaturesManager:
             "multiple-choice",
             "token-classification",
             "question-answering",
-            onnx_config_cls=RobertaOnnxConfig,
-        ),
-        "t5": supported_features_mapping(
-            "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
+            onnx_config_cls=ElectraOnnxConfig,
         ),
-        "xlm-roberta": supported_features_mapping(
+        "flaubert": supported_features_mapping(
             "default",
             "masked-lm",
             "causal-lm",
@@ -253,7 +240,7 @@ class FeaturesManager:
             "multiple-choice",
             "token-classification",
             "question-answering",
-            onnx_config_cls=XLMRobertaOnnxConfig,
+            onnx_config_cls=FlaubertOnnxConfig,
         ),
         "gpt2": supported_features_mapping(
             "default",
@@ -281,58 +268,54 @@ class FeaturesManager:
             "sequence-classification",
             onnx_config_cls=GPTNeoOnnxConfig,
         ),
-        "layoutlm": supported_features_mapping(
-            "default",
-            "masked-lm",
-            "sequence-classification",
-            "token-classification",
-            onnx_config_cls=LayoutLMOnnxConfig,
-        ),
-        "electra": supported_features_mapping(
-            "default",
-            "masked-lm",
-            "causal-lm",
-            "sequence-classification",
-            "multiple-choice",
-            "token-classification",
-            "question-answering",
-            onnx_config_cls=ElectraOnnxConfig,
-        ),
-        "vit": supported_features_mapping(
-            "default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
-        ),
-        "beit": supported_features_mapping(
-            "default", "image-classification", "masked-im", onnx_config_cls=BeitOnnxConfig
-        ),
-        "deit": supported_features_mapping(
-            "default", "image-classification", "masked-im", onnx_config_cls=DeiTOnnxConfig
-        ),
-        "blenderbot": supported_features_mapping(
-            "default",
-            "default-with-past",
-            "causal-lm",
-            "causal-lm-with-past",
-            "seq2seq-lm",
-            "seq2seq-lm-with-past",
-            onnx_config_cls=BlenderbotOnnxConfig,
-        ),
-        "blenderbot-small": supported_features_mapping(
-            "default",
-            "default-with-past",
-            "causal-lm",
-            "causal-lm-with-past",
-            "seq2seq-lm",
-            "seq2seq-lm-with-past",
-            onnx_config_cls=BlenderbotSmallOnnxConfig,
-        ),
-        "data2vec-text": supported_features_mapping(
-            "default",
-            "masked-lm",
-            "sequence-classification",
-            "multiple-choice",
-            "token-classification",
-            "question-answering",
-            onnx_config_cls=Data2VecTextOnnxConfig,
+        "ibert": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "sequence-classification",
+            "multiple-choice",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls=IBertOnnxConfig,
+        ),
+        "layoutlm": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "sequence-classification",
+            "token-classification",
+            onnx_config_cls=LayoutLMOnnxConfig,
+        ),
+        "marian": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            "causal-lm",
+            "causal-lm-with-past",
+            onnx_config_cls=MarianOnnxConfig,
+        ),
+        "mbart": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "causal-lm",
+            "causal-lm-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            "sequence-classification",
+            "question-answering",
+            onnx_config_cls=MBartOnnxConfig,
+        ),
+        "m2m-100": supported_features_mapping(
+            "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
+        ),
+        "roberta": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "causal-lm",
+            "sequence-classification",
+            "multiple-choice",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls=RobertaOnnxConfig,
         ),
         "roformer": supported_features_mapping(
             "default",
@@ -345,6 +328,22 @@ class FeaturesManager:
             "token-classification",
             onnx_config_cls=RoFormerOnnxConfig,
         ),
+        "t5": supported_features_mapping(
+            "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
+        ),
+        "vit": supported_features_mapping(
+            "default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
+        ),
+        "xlm-roberta": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "causal-lm",
+            "sequence-classification",
+            "multiple-choice",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls=XLMRobertaOnnxConfig,
+        ),
     }

     AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
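With the alphabetized mapping above, the BEiT entry now exposes only the default and image-classification features, so masked-im is no longer offered for BEiT ONNX export (ViT and DeiT keep it). A quick check, sketched under the assumption that `FeaturesManager.get_supported_features_for_model_type` is the lookup used by the exporter:

from transformers.onnx.features import FeaturesManager

# Per the mapping above, "masked-im" should no longer be listed for BEiT.
print(sorted(FeaturesManager.get_supported_features_for_model_type("beit")))
# Expected: ['default', 'image-classification']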