Unverified Commit 77de8d6c authored by Ella Charlaix's avatar Ella Charlaix Committed by GitHub
Browse files

Add onnx export of models with a multiple choice classification head (#16758)

* Add export of models with a multiple-choice classification head
parent b74a9553
...@@ -159,10 +159,14 @@ class AlbertConfig(PretrainedConfig): ...@@ -159,10 +159,14 @@ class AlbertConfig(PretrainedConfig):
class AlbertOnnxConfig(OnnxConfig): class AlbertOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", dynamic_axis),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", dynamic_axis),
("token_type_ids", {0: "batch", 1: "sequence"}), ("token_type_ids", dynamic_axis),
] ]
) )
...@@ -160,10 +160,14 @@ class BertConfig(PretrainedConfig): ...@@ -160,10 +160,14 @@ class BertConfig(PretrainedConfig):
class BertOnnxConfig(OnnxConfig): class BertOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", dynamic_axis),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", dynamic_axis),
("token_type_ids", {0: "batch", 1: "sequence"}), ("token_type_ids", dynamic_axis),
] ]
) )
...@@ -168,9 +168,13 @@ class BigBirdConfig(PretrainedConfig): ...@@ -168,9 +168,13 @@ class BigBirdConfig(PretrainedConfig):
class BigBirdOnnxConfig(OnnxConfig): class BigBirdOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", dynamic_axis),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", dynamic_axis),
] ]
) )
...@@ -44,9 +44,13 @@ class CamembertConfig(RobertaConfig): ...@@ -44,9 +44,13 @@ class CamembertConfig(RobertaConfig):
class CamembertOnnxConfig(OnnxConfig): class CamembertOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", dynamic_axis),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", dynamic_axis),
] ]
) )
...@@ -139,9 +139,13 @@ class Data2VecTextConfig(PretrainedConfig): ...@@ -139,9 +139,13 @@ class Data2VecTextConfig(PretrainedConfig):
class Data2VecTextOnnxConfig(OnnxConfig): class Data2VecTextOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", dynamic_axis),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", dynamic_axis),
] ]
) )
...@@ -134,9 +134,13 @@ class DistilBertConfig(PretrainedConfig): ...@@ -134,9 +134,13 @@ class DistilBertConfig(PretrainedConfig):
class DistilBertOnnxConfig(OnnxConfig): class DistilBertOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", dynamic_axis),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", dynamic_axis),
] ]
) )
...@@ -179,10 +179,14 @@ class ElectraConfig(PretrainedConfig): ...@@ -179,10 +179,14 @@ class ElectraConfig(PretrainedConfig):
class ElectraOnnxConfig(OnnxConfig): class ElectraOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", dynamic_axis),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", dynamic_axis),
("token_type_ids", {0: "batch", 1: "sequence"}), ("token_type_ids", dynamic_axis),
] ]
) )
...@@ -146,9 +146,13 @@ class FlaubertConfig(XLMConfig): ...@@ -146,9 +146,13 @@ class FlaubertConfig(XLMConfig):
class FlaubertOnnxConfig(OnnxConfig): class FlaubertOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", dynamic_axis),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", dynamic_axis),
] ]
) )
...@@ -234,7 +234,7 @@ class GPT2OnnxConfig(OnnxConfigWithPast): ...@@ -234,7 +234,7 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
framework: Optional[TensorType] = None, framework: Optional[TensorType] = None,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
) )
# We need to order the input in the way they appears in the forward() # We need to order the input in the way they appears in the forward()
......
...@@ -233,7 +233,7 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast): ...@@ -233,7 +233,7 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
) )
# We need to order the input in the way they appears in the forward() # We need to order the input in the way they appears in the forward()
......
...@@ -183,7 +183,7 @@ class GPTJOnnxConfig(OnnxConfigWithPast): ...@@ -183,7 +183,7 @@ class GPTJOnnxConfig(OnnxConfigWithPast):
framework: Optional[TensorType] = None, framework: Optional[TensorType] = None,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
) )
# We need to order the input in the way they appears in the forward() # We need to order the input in the way they appears in the forward()
......
...@@ -131,9 +131,13 @@ class IBertConfig(PretrainedConfig): ...@@ -131,9 +131,13 @@ class IBertConfig(PretrainedConfig):
class IBertOnnxConfig(OnnxConfig): class IBertOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", dynamic_axis),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", dynamic_axis),
] ]
) )
...@@ -171,7 +171,9 @@ class LayoutLMOnnxConfig(OnnxConfig): ...@@ -171,7 +171,9 @@ class LayoutLMOnnxConfig(OnnxConfig):
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
""" """
input_dict = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) input_dict = super().generate_dummy_inputs(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
# Generate a dummy bbox # Generate a dummy bbox
box = [48, 84, 73, 128] box = [48, 84, 73, 128]
......
...@@ -70,9 +70,13 @@ class RobertaConfig(BertConfig): ...@@ -70,9 +70,13 @@ class RobertaConfig(BertConfig):
class RobertaOnnxConfig(OnnxConfig): class RobertaOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", dynamic_axis),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", dynamic_axis),
] ]
) )
...@@ -47,9 +47,13 @@ class XLMRobertaConfig(RobertaConfig): ...@@ -47,9 +47,13 @@ class XLMRobertaConfig(RobertaConfig):
class XLMRobertaOnnxConfig(OnnxConfig): class XLMRobertaOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", dynamic_axis),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", dynamic_axis),
] ]
) )
...@@ -143,9 +143,13 @@ class XLMRobertaXLConfig(PretrainedConfig): ...@@ -143,9 +143,13 @@ class XLMRobertaXLConfig(PretrainedConfig):
class XLMRobertaXLOnnxConfig(OnnxConfig): class XLMRobertaXLOnnxConfig(OnnxConfig):
@property @property
def inputs(self) -> Mapping[str, Mapping[int, str]]: def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( return OrderedDict(
[ [
("input_ids", {0: "batch", 1: "sequence"}), ("input_ids", dynamic_axis),
("attention_mask", {0: "batch", 1: "sequence"}), ("attention_mask", dynamic_axis),
] ]
) )
...@@ -71,6 +71,7 @@ class OnnxConfig(ABC): ...@@ -71,6 +71,7 @@ class OnnxConfig(ABC):
default_fixed_batch = 2 default_fixed_batch = 2
default_fixed_sequence = 8 default_fixed_sequence = 8
default_fixed_num_choices = 4
torch_onnx_minimum_version = version.parse("1.8") torch_onnx_minimum_version = version.parse("1.8")
_tasks_to_common_outputs = { _tasks_to_common_outputs = {
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}), "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
...@@ -174,6 +175,16 @@ class OnnxConfig(ABC): ...@@ -174,6 +175,16 @@ class OnnxConfig(ABC):
""" """
return OnnxConfig.default_fixed_sequence return OnnxConfig.default_fixed_sequence
@property
def default_num_choices(self) -> int:
"""
The default number of choices to use if no other indication
Returns:
Integer > 0
"""
return OnnxConfig.default_fixed_num_choices
@property @property
def default_onnx_opset(self) -> int: def default_onnx_opset(self) -> int:
""" """
...@@ -240,6 +251,7 @@ class OnnxConfig(ABC): ...@@ -240,6 +251,7 @@ class OnnxConfig(ABC):
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"], preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
batch_size: int = -1, batch_size: int = -1,
seq_length: int = -1, seq_length: int = -1,
num_choices: int = -1,
is_pair: bool = False, is_pair: bool = False,
framework: Optional[TensorType] = None, framework: Optional[TensorType] = None,
num_channels: int = 3, num_channels: int = 3,
...@@ -255,6 +267,8 @@ class OnnxConfig(ABC): ...@@ -255,6 +267,8 @@ class OnnxConfig(ABC):
The preprocessor associated with this model configuration. The preprocessor associated with this model configuration.
batch_size (`int`, *optional*, defaults to -1): batch_size (`int`, *optional*, defaults to -1):
The batch size to export the model for (-1 means dynamic axis). The batch size to export the model for (-1 means dynamic axis).
num_choices (`int`, *optional*, defaults to -1):
The number of candidate answers provided for multiple choice task (-1 means dynamic axis).
seq_length (`int`, *optional*, defaults to -1): seq_length (`int`, *optional*, defaults to -1):
The sequence length to export the model for (-1 means dynamic axis). The sequence length to export the model for (-1 means dynamic axis).
is_pair (`bool`, *optional*, defaults to `False`): is_pair (`bool`, *optional*, defaults to `False`):
...@@ -295,6 +309,19 @@ class OnnxConfig(ABC): ...@@ -295,6 +309,19 @@ class OnnxConfig(ABC):
) )
# Generate dummy inputs according to compute batch and sequence # Generate dummy inputs according to compute batch and sequence
dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size
if self.task == "multiple-choice":
# If dynamic axis (-1) we forward with a fixed dimension of 4 candidate answers to avoid optimizations
# made by ONNX
num_choices = compute_effective_axis_dimension(
num_choices, fixed_dimension=OnnxConfig.default_fixed_num_choices, num_token_to_add=0
)
dummy_input = dummy_input * num_choices
# The shape of the tokenized inputs values is [batch_size * num_choices, seq_length]
tokenized_input = preprocessor(dummy_input, text_pair=dummy_input)
# Unflatten the tokenized inputs values expanding it to the shape [batch_size, num_choices, seq_length]
for k, v in tokenized_input.items():
tokenized_input[k] = [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
return dict(tokenized_input.convert_to_tensors(tensor_type=framework))
return dict(preprocessor(dummy_input, return_tensors=framework)) return dict(preprocessor(dummy_input, return_tensors=framework))
elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values": elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
...@@ -408,7 +435,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC): ...@@ -408,7 +435,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
# TODO: should we set seq_length = 1 when self.use_past = True? # TODO: should we set seq_length = 1 when self.use_past = True?
common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) common_inputs = super().generate_dummy_inputs(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
if self.use_past: if self.use_past:
if not is_torch_available(): if not is_torch_available():
...@@ -527,13 +556,13 @@ class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast): ...@@ -527,13 +556,13 @@ class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
) )
# Generate decoder inputs # Generate decoder inputs
decoder_seq_length = seq_length if not self.use_past else 1 decoder_seq_length = seq_length if not self.use_past else 1
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, decoder_seq_length, is_pair, framework tokenizer, batch_size=batch_size, seq_length=decoder_seq_length, is_pair=is_pair, framework=framework
) )
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
common_inputs = dict(**encoder_inputs, **decoder_inputs) common_inputs = dict(**encoder_inputs, **decoder_inputs)
......
...@@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig ...@@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig
from ..models.blenderbot import BlenderbotOnnxConfig from ..models.blenderbot import BlenderbotOnnxConfig
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
from ..models.camembert import CamembertOnnxConfig from ..models.camembert import CamembertOnnxConfig
from ..models.data2vec import Data2VecTextOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig from ..models.distilbert import DistilBertOnnxConfig
from ..models.electra import ElectraOnnxConfig from ..models.electra import ElectraOnnxConfig
from ..models.flaubert import FlaubertOnnxConfig from ..models.flaubert import FlaubertOnnxConfig
...@@ -120,7 +121,7 @@ class FeaturesManager: ...@@ -120,7 +121,7 @@ class FeaturesManager:
"default", "default",
"masked-lm", "masked-lm",
"sequence-classification", "sequence-classification",
# "multiple-choice", "multiple-choice",
"token-classification", "token-classification",
"question-answering", "question-answering",
onnx_config_cls=AlbertOnnxConfig, onnx_config_cls=AlbertOnnxConfig,
...@@ -152,7 +153,7 @@ class FeaturesManager: ...@@ -152,7 +153,7 @@ class FeaturesManager:
"masked-lm", "masked-lm",
"causal-lm", "causal-lm",
"sequence-classification", "sequence-classification",
# "multiple-choice", "multiple-choice",
"token-classification", "token-classification",
"question-answering", "question-answering",
onnx_config_cls=BertOnnxConfig, onnx_config_cls=BertOnnxConfig,
...@@ -162,6 +163,7 @@ class FeaturesManager: ...@@ -162,6 +163,7 @@ class FeaturesManager:
"masked-lm", "masked-lm",
"causal-lm", "causal-lm",
"sequence-classification", "sequence-classification",
"multiple-choice",
"token-classification", "token-classification",
"question-answering", "question-answering",
onnx_config_cls=BigBirdOnnxConfig, onnx_config_cls=BigBirdOnnxConfig,
...@@ -170,7 +172,7 @@ class FeaturesManager: ...@@ -170,7 +172,7 @@ class FeaturesManager:
"default", "default",
"masked-lm", "masked-lm",
"sequence-classification", "sequence-classification",
# "multiple-choice", "multiple-choice",
"token-classification", "token-classification",
"question-answering", "question-answering",
onnx_config_cls=IBertOnnxConfig, onnx_config_cls=IBertOnnxConfig,
...@@ -180,7 +182,7 @@ class FeaturesManager: ...@@ -180,7 +182,7 @@ class FeaturesManager:
"masked-lm", "masked-lm",
"causal-lm", "causal-lm",
"sequence-classification", "sequence-classification",
# "multiple-choice", "multiple-choice",
"token-classification", "token-classification",
"question-answering", "question-answering",
onnx_config_cls=CamembertOnnxConfig, onnx_config_cls=CamembertOnnxConfig,
...@@ -189,7 +191,7 @@ class FeaturesManager: ...@@ -189,7 +191,7 @@ class FeaturesManager:
"default", "default",
"masked-lm", "masked-lm",
"sequence-classification", "sequence-classification",
# "multiple-choice", "multiple-choice",
"token-classification", "token-classification",
"question-answering", "question-answering",
onnx_config_cls=DistilBertOnnxConfig, onnx_config_cls=DistilBertOnnxConfig,
...@@ -199,6 +201,7 @@ class FeaturesManager: ...@@ -199,6 +201,7 @@ class FeaturesManager:
"masked-lm", "masked-lm",
"causal-lm", "causal-lm",
"sequence-classification", "sequence-classification",
"multiple-choice",
"token-classification", "token-classification",
"question-answering", "question-answering",
onnx_config_cls=FlaubertOnnxConfig, onnx_config_cls=FlaubertOnnxConfig,
...@@ -220,7 +223,7 @@ class FeaturesManager: ...@@ -220,7 +223,7 @@ class FeaturesManager:
"masked-lm", "masked-lm",
"causal-lm", "causal-lm",
"sequence-classification", "sequence-classification",
# "multiple-choice", "multiple-choice",
"token-classification", "token-classification",
"question-answering", "question-answering",
onnx_config_cls=RobertaOnnxConfig, onnx_config_cls=RobertaOnnxConfig,
...@@ -233,7 +236,7 @@ class FeaturesManager: ...@@ -233,7 +236,7 @@ class FeaturesManager:
"masked-lm", "masked-lm",
"causal-lm", "causal-lm",
"sequence-classification", "sequence-classification",
# "multiple-choice", "multiple-choice",
"token-classification", "token-classification",
"question-answering", "question-answering",
onnx_config_cls=XLMRobertaOnnxConfig, onnx_config_cls=XLMRobertaOnnxConfig,
...@@ -276,6 +279,7 @@ class FeaturesManager: ...@@ -276,6 +279,7 @@ class FeaturesManager:
"masked-lm", "masked-lm",
"causal-lm", "causal-lm",
"sequence-classification", "sequence-classification",
"multiple-choice",
"token-classification", "token-classification",
"question-answering", "question-answering",
onnx_config_cls=ElectraOnnxConfig, onnx_config_cls=ElectraOnnxConfig,
...@@ -300,6 +304,15 @@ class FeaturesManager: ...@@ -300,6 +304,15 @@ class FeaturesManager:
"seq2seq-lm-with-past", "seq2seq-lm-with-past",
onnx_config_cls=BlenderbotSmallOnnxConfig, onnx_config_cls=BlenderbotSmallOnnxConfig,
), ),
"data2vec-text": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=Data2VecTextOnnxConfig,
),
} }
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values()))) AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
......
...@@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = {
("layoutlm", "microsoft/layoutlm-base-uncased"), ("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"), ("vit", "google/vit-base-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"), ("beit", "microsoft/beit-base-patch16-224"),
("data2vec-text", "facebook/data2vec-text-base"),
} }
PYTORCH_EXPORT_WITH_PAST_MODELS = { PYTORCH_EXPORT_WITH_PAST_MODELS = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment