Unverified commit 3f94170a authored by Lysandre Debut, committed by GitHub

[WIP] Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleChoice} models and tests (#5614)

* Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleChoice} models and tests

* AutoModels


Tiny tweaks

* Style

* Final changes before merge

* Re-order for simpler review

* Final fixes

* Addressing @sgugger's comments

* Test MultipleChoice
parent 8a8ae276
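
Reviewer note: before the per-file diffs, a minimal usage sketch of the new PyTorch multiple-choice head added in this commit. The `xlm-mlm-en-2048` checkpoint and the two-choice prompt are illustrative, and the classification head is freshly initialized here, so the scores are untrained:

    import torch
    from transformers import XLMTokenizer, XLMForMultipleChoice

    tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
    model = XLMForMultipleChoice.from_pretrained("xlm-mlm-en-2048")

    prompt = "The sky is"
    choices = ["blue.", "a potato."]

    # Encode each (prompt, choice) pair; tensors come out as (num_choices, seq_len).
    encoding = tokenizer([prompt] * len(choices), choices, padding=True, return_tensors="pt")
    # The model expects (batch, num_choices, seq_len), so add a batch dimension of 1.
    inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}

    labels = torch.tensor([0])  # index of the correct choice
    outputs = model(**inputs, labels=labels)
    loss, logits = outputs.loss, outputs.logits  # logits: (batch, num_choices)

`FlaubertForMultipleChoice` follows the same interface: as the diff below shows, it subclasses `XLMForMultipleChoice` and only swaps in a `FlaubertModel` backbone.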
src/transformers/__init__.py

@@ -278,6 +278,7 @@ if is_torch_available():
         XLMForTokenClassification,
         XLMForQuestionAnswering,
         XLMForQuestionAnsweringSimple,
+        XLMForMultipleChoice,
         XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_bart import (
@@ -356,6 +357,8 @@ if is_torch_available():
         FlaubertForTokenClassification,
         FlaubertForQuestionAnswering,
         FlaubertForQuestionAnsweringSimple,
+        FlaubertForTokenClassification,
+        FlaubertForMultipleChoice,
         FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
src/transformers/modeling_auto.py

@@ -98,6 +98,7 @@ from .modeling_electra import (
 )
 from .modeling_encoder_decoder import EncoderDecoderModel
 from .modeling_flaubert import (
+    FlaubertForMultipleChoice,
     FlaubertForQuestionAnsweringSimple,
     FlaubertForSequenceClassification,
     FlaubertForTokenClassification,
@@ -142,6 +143,7 @@ from .modeling_roberta import (
 from .modeling_t5 import T5ForConditionalGeneration, T5Model
 from .modeling_transfo_xl import TransfoXLLMHeadModel, TransfoXLModel
 from .modeling_xlm import (
+    XLMForMultipleChoice,
     XLMForQuestionAnsweringSimple,
     XLMForSequenceClassification,
     XLMForTokenClassification,
@@ -338,6 +340,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
         (XLNetConfig, XLNetForTokenClassification),
         (AlbertConfig, AlbertForTokenClassification),
         (ElectraConfig, ElectraForTokenClassification),
+        (FlaubertConfig, FlaubertForTokenClassification),
     ]
 )
@@ -353,6 +356,8 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
         (MobileBertConfig, MobileBertForMultipleChoice),
         (XLNetConfig, XLNetForMultipleChoice),
         (AlbertConfig, AlbertForMultipleChoice),
+        (XLMConfig, XLMForMultipleChoice),
+        (FlaubertConfig, FlaubertForMultipleChoice),
     ]
 )
src/transformers/modeling_flaubert.py

@@ -25,6 +25,7 @@ from .configuration_flaubert import FlaubertConfig
 from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_outputs import BaseModelOutput
 from .modeling_xlm import (
+    XLMForMultipleChoice,
     XLMForQuestionAnswering,
     XLMForQuestionAnsweringSimple,
     XLMForSequenceClassification,
@@ -382,3 +383,22 @@ class FlaubertForQuestionAnswering(XLMForQuestionAnswering):
         super().__init__(config)
         self.transformer = FlaubertModel(config)
         self.init_weights()
+
+
+@add_start_docstrings(
+    """Flaubert Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    FLAUBERT_START_DOCSTRING,
+)
+class FlaubertForMultipleChoice(XLMForMultipleChoice):
+    """
+    This class overrides :class:`~transformers.XLMForMultipleChoice`. Please check the
+    superclass for the appropriate documentation alongside usage examples.
+    """
+
+    config_class = FlaubertConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = FlaubertModel(config)
+        self.init_weights()
src/transformers/modeling_tf_flaubert.py

@@ -22,7 +22,7 @@ import tensorflow as tf
 from .configuration_flaubert import FlaubertConfig
 from .file_utils import add_start_docstrings
-from .modeling_tf_utils import keras_serializable, shape_list
+from .modeling_tf_utils import cast_bool_to_primitive, keras_serializable, shape_list
 from .modeling_tf_xlm import (
     TFXLMForMultipleChoice,
     TFXLMForQuestionAnsweringSimple,
@@ -30,6 +30,7 @@ from .modeling_tf_xlm import (
     TFXLMForTokenClassification,
     TFXLMMainLayer,
     TFXLMModel,
+    TFXLMPredLayer,
     TFXLMWithLMHeadModel,
     get_masks,
 )
@@ -123,6 +124,8 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
         super().__init__(config, *inputs, **kwargs)
         self.layerdrop = getattr(config, "layerdrop", 0.0)
         self.pre_norm = getattr(config, "pre_norm", False)
+        self.output_attentions = config.output_attentions
+        self.output_hidden_states = config.output_hidden_states

     def call(
         self,
@@ -135,9 +138,9 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
         cache=None,
         head_mask=None,
         inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
         training=False,
-        output_attentions=False,
-        output_hidden_states=False,
     ):
         # removed: src_enc=None, src_len=None
         if isinstance(inputs, (tuple, list)):
@@ -150,7 +153,9 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
             cache = inputs[6] if len(inputs) > 6 else cache
             head_mask = inputs[7] if len(inputs) > 7 else head_mask
             inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
-            assert len(inputs) <= 9, "Too many inputs."
+            output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
+            output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
+            assert len(inputs) <= 11, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
@@ -161,10 +166,15 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
             cache = inputs.get("cache", cache)
             head_mask = inputs.get("head_mask", head_mask)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
-            assert len(inputs) <= 9, "Too many inputs."
+            output_attentions = inputs.get("output_attentions", output_attentions)
+            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
+            assert len(inputs) <= 11, "Too many inputs."
         else:
             input_ids = inputs

+        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -257,8 +267,11 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
             # self attention
             if not self.pre_norm:
-                attn_outputs = self.attentions[i]([tensor, attn_mask, None, cache, head_mask[i]], training=training)
+                attn_outputs = self.attentions[i](
+                    [tensor, attn_mask, None, cache, head_mask[i], output_attentions], training=training
+                )
                 attn = attn_outputs[0]
-                if output_attentions:
+                if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
                     attentions = attentions + (attn_outputs[1],)
                 attn = self.dropout(attn, training=training)
                 tensor = tensor + attn
@@ -269,7 +282,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
                     [tensor_normalized, attn_mask, None, cache, head_mask[i]], training=training
                 )
                 attn = attn_outputs[0]
-                if output_attentions:
+                if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
                     attentions = attentions + (attn_outputs[1],)
                 attn = self.dropout(attn, training=training)
                 tensor = tensor + attn
@@ -292,7 +305,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
             tensor = tensor * mask[..., tf.newaxis]

         # Add last hidden state
-        if output_hidden_states:
+        if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
             hidden_states = hidden_states + (tensor,)

         # update cache length
@@ -303,9 +316,9 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
         # tensor = tensor.transpose(0, 1)

         outputs = (tensor,)
-        if output_hidden_states:
+        if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
             outputs = outputs + (hidden_states,)
-        if output_attentions:
+        if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
             outputs = outputs + (attentions,)
         return outputs  # outputs, (hidden_states), (attentions)
@@ -321,6 +334,7 @@ class TFFlaubertWithLMHeadModel(TFXLMWithLMHeadModel):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.transformer = TFFlaubertMainLayer(config, name="transformer")
+        self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")

 @add_start_docstrings(
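
A note on the `cast_bool_to_primitive` guards introduced above: once a Keras layer has been serialized or traced, a flag such as `output_attentions` can arrive as a `tf.Tensor` rather than a Python bool, and `if some_symbolic_tensor:` is not valid in graph mode. The sketch below illustrates the underlying idea with a hypothetical `to_py_bool` helper; it is not the library utility, whose exact signature and fallback behavior may differ:

    import tensorflow as tf

    def to_py_bool(flag, default=False):
        # Python bools pass straight through.
        if isinstance(flag, bool):
            return flag
        # Eager (constant) tensors can be unwrapped to a concrete bool.
        if isinstance(flag, tf.Tensor):
            try:
                return bool(flag.numpy())
            except (AttributeError, NotImplementedError):
                # Symbolic tensor during tracing: fall back to the config default.
                return default
        return default

    # Usage mirroring the guards in the diff:
    #   if to_py_bool(output_attentions, self.output_attentions) is True:
    #       attentions = attentions + (attn_outputs[1],)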
src/transformers/modeling_tf_xlm.py

@@ -19,6 +19,7 @@
 import itertools
 import logging
 import math
+import warnings

 import numpy as np
 import tensorflow as tf
@@ -827,6 +828,9 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
         self.transformer = TFXLMMainLayer(config, name="transformer")
         self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
+        self.logits_proj = tf.keras.layers.Dense(
+            1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
+        )

     @property
     def dummy_inputs(self):
@@ -835,7 +839,10 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
         Returns:
             tf.Tensor with dummy inputs
         """
-        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
+        return {
+            "input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
+            "langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
+        }

     @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlm-mlm-en-2048")
@@ -892,7 +899,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
             output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
             output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
             labels = inputs[11] if len(inputs) > 11 else labels
-            assert len(inputs) <= 11, "Too many inputs."
+            assert len(inputs) <= 12, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
@@ -921,17 +928,31 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
         flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
         flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+        flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
+        flat_inputs_embeds = (
+            tf.reshape(inputs_embeds, (-1, inputs_embeds.shape[-2], inputs_embeds.shape[-1]))
+            if inputs_embeds is not None
+            else None
+        )
+
+        if lengths is not None:
+            warnings.warn(
+                "The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
+                "attention mask instead.",
+                FutureWarning,
+            )
+            lengths = None
+
         flat_inputs = [
             flat_input_ids,
             flat_attention_mask,
-            langs,
+            flat_langs,
             flat_token_type_ids,
             flat_position_ids,
             lengths,
             cache,
             head_mask,
-            inputs_embeds,
+            flat_inputs_embeds,
             output_attentions,
             output_hidden_states,
         ]
@@ -939,6 +960,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
         transformer_outputs = self.transformer(flat_inputs, training=training)
         output = transformer_outputs[0]
         logits = self.sequence_summary(output)
+        logits = self.logits_proj(logits)
         reshaped_logits = tf.reshape(logits, (-1, num_choices))

         outputs = (reshaped_logits,) + transformer_outputs[1:]  # add hidden states and attention if they are here
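
The shape bookkeeping in `TFXLMForMultipleChoice.call` is the standard multiple-choice trick: fold the choices into the batch axis, run the shared transformer once, then unfold the per-sequence scores. A standalone shape check (sizes are illustrative):

    import tensorflow as tf

    batch, num_choices, seq_len = 2, 4, 7
    input_ids = tf.zeros((batch, num_choices, seq_len), dtype=tf.int32)

    # Flatten choices into the batch dimension, as flat_input_ids/flat_langs/... do above.
    flat_input_ids = tf.reshape(input_ids, (-1, seq_len))      # (8, 7)

    # Stand-in for transformer -> sequence_summary -> logits_proj: one score per sequence.
    logits = tf.zeros((batch * num_choices, 1))                # (8, 1)
    reshaped_logits = tf.reshape(logits, (-1, num_choices))    # (2, 4)

    print(flat_input_ids.shape, reshaped_logits.shape)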
src/transformers/modeling_xlm.py

@@ -19,6 +19,7 @@
 import itertools
 import logging
 import math
+import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple
@@ -40,6 +41,7 @@ from .file_utils import (
 from .modeling_outputs import (
     BaseModelOutput,
     MaskedLMOutput,
+    MultipleChoiceModelOutput,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutput,
     TokenClassifierOutput,
@@ -1122,3 +1124,105 @@ class XLMForTokenClassification(XLMPreTrainedModel):
         return TokenClassifierOutput(
             loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """XLM Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    XLM_START_DOCSTRING,
+)
+class XLMForMultipleChoice(XLMPreTrainedModel):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.transformer = XLMModel(config)
+        self.sequence_summary = SequenceSummary(config)
+        self.logits_proj = nn.Linear(config.num_labels, 1)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="xlm-mlm-en-2048",
+        output_type=MultipleChoiceModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        langs=None,
+        token_type_ids=None,
+        position_ids=None,
+        lengths=None,
+        cache=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_tuple=None,
+    ):
+        r"""
+        labels (:obj:`torch.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+        """
+        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+        langs = langs.view(-1, langs.size(-1)) if langs is not None else None
+        inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
+
+        if lengths is not None:
+            warnings.warn(
+                "The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
+                "attention mask instead.",
+                FutureWarning,
+            )
+            lengths = None
+
+        transformer_outputs = self.transformer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            langs=langs,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            lengths=lengths,
+            cache=cache,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_tuple=return_tuple,
+        )
+        output = transformer_outputs[0]
+        logits = self.sequence_summary(output)
+        logits = self.logits_proj(logits)
+        reshaped_logits = logits.view(-1, num_choices)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(reshaped_logits, labels)
+
+        if return_tuple:
+            output = (reshaped_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MultipleChoiceModelOutput(
+            loss=loss,
+            logits=reshaped_logits,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
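
One non-obvious detail in `XLMForMultipleChoice` above is `self.logits_proj = nn.Linear(config.num_labels, 1)`: with XLM's default summary settings (`summary_use_proj` and `summary_proj_to_labels` enabled), `SequenceSummary` already projects the pooled hidden state to `num_labels` dimensions, so the extra linear layer reduces that vector to a single score per choice before the view back to (batch, num_choices). A shape-only sketch with made-up sizes and plain `nn.Linear` stand-ins for the two heads:

    import torch
    import torch.nn as nn

    batch, num_choices, seq_len, hidden, num_labels = 2, 4, 7, 32, 2

    # Stand-ins: transformer output, SequenceSummary's projection, and logits_proj.
    hidden_states = torch.randn(batch * num_choices, seq_len, hidden)
    summary = nn.Linear(hidden, num_labels)
    logits_proj = nn.Linear(num_labels, 1)

    pooled = hidden_states[:, 0]                     # pooled token, (8, 32)
    logits = logits_proj(summary(pooled))            # (8, 1)
    reshaped_logits = logits.view(-1, num_choices)   # (2, 4)

    labels = torch.tensor([0, 3])                    # correct choice per example
    loss = nn.CrossEntropyLoss()(reshaped_logits, labels)

This also appears to be why the XLM test diff below passes `num_labels=self.num_labels` into `XLMConfig`: the summary projection and `logits_proj` have to agree on that dimension.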
tests/test_modeling_common.py

@@ -66,7 +66,7 @@ class ModelTesterMixin:
         if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
             return {
                 k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
-                if isinstance(v, torch.Tensor) and v.ndim != 0
+                if isinstance(v, torch.Tensor) and v.ndim > 1
                 else v
                 for k, v in inputs_dict.items()
             }
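
The `v.ndim != 0` to `v.ndim > 1` change matters for XLM-style inputs: a 1-D tensor such as `lengths` has shape (batch,), and `unsqueeze(1).expand(-1, num_choices, -1)` would raise on it (three target dims for a 2-D tensor), so only tensors that already have a sequence axis should gain the choices axis. A small sketch (the `lengths` key is illustrative):

    import torch

    batch, num_choices, seq_len = 2, 4, 7
    inputs_dict = {
        "input_ids": torch.zeros(batch, seq_len, dtype=torch.long),   # expanded
        "lengths": torch.full((batch,), seq_len, dtype=torch.long),   # left alone
    }

    expanded = {
        k: v.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
        if isinstance(v, torch.Tensor) and v.ndim > 1
        else v
        for k, v in inputs_dict.items()
    }

    print(expanded["input_ids"].shape)  # torch.Size([2, 4, 7])
    print(expanded["lengths"].shape)    # torch.Size([2])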
tests/test_modeling_flaubert.py

@@ -32,6 +32,7 @@ if is_torch_available():
         FlaubertForQuestionAnsweringSimple,
         FlaubertForSequenceClassification,
         FlaubertForTokenClassification,
+        FlaubertForMultipleChoice,
     )
     from transformers.modeling_flaubert import FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -90,6 +91,7 @@ class FlaubertModelTester(object):
             sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
             token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
             is_impossible_labels = ids_tensor([self.batch_size], 2).float()
+            choice_labels = ids_tensor([self.batch_size], self.num_choices)

         config = FlaubertConfig(
             vocab_size=self.vocab_size,
@@ -118,6 +120,7 @@ class FlaubertModelTester(object):
             sequence_labels,
             token_labels,
             is_impossible_labels,
+            choice_labels,
             input_mask,
         )
@@ -133,6 +136,7 @@ class FlaubertModelTester(object):
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = FlaubertModel(config=config)
@@ -158,6 +162,7 @@ class FlaubertModelTester(object):
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = FlaubertWithLMHeadModel(config)
@@ -183,6 +188,7 @@ class FlaubertModelTester(object):
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = FlaubertForQuestionAnsweringSimple(config)
@@ -212,6 +218,7 @@ class FlaubertModelTester(object):
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = FlaubertForQuestionAnswering(config)
@@ -278,6 +285,7 @@ class FlaubertModelTester(object):
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = FlaubertForSequenceClassification(config)
@@ -304,6 +312,7 @@ class FlaubertModelTester(object):
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         config.num_labels = self.num_labels
@@ -319,6 +328,38 @@ class FlaubertModelTester(object):
         self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
         self.check_loss_output(result)

+    def create_and_check_flaubert_multiple_choice(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_lengths,
+        sequence_labels,
+        token_labels,
+        is_impossible_labels,
+        choice_labels,
+        input_mask,
+    ):
+        config.num_choices = self.num_choices
+        model = FlaubertForMultipleChoice(config=config)
+        model.to(torch_device)
+        model.eval()
+        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        loss, logits = model(
+            multiple_choice_inputs_ids,
+            attention_mask=multiple_choice_input_mask,
+            token_type_ids=multiple_choice_token_type_ids,
+            labels=choice_labels,
+        )
+        result = {
+            "loss": loss,
+            "logits": logits,
+        }
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
+        self.check_loss_output(result)
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -329,6 +370,7 @@ class FlaubertModelTester(object):
             sequence_labels,
             token_labels,
             is_impossible_labels,
+            choice_labels,
             input_mask,
         ) = config_and_inputs
         inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths}
@@ -346,6 +388,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
             FlaubertForQuestionAnsweringSimple,
             FlaubertForSequenceClassification,
             FlaubertForTokenClassification,
+            FlaubertForMultipleChoice,
         )
         if is_torch_available()
         else ()
@@ -382,6 +425,10 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_flaubert_token_classif(*config_and_inputs)

+    def test_flaubert_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_flaubert_multiple_choice(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
tests/test_modeling_tf_common.py

@@ -80,8 +80,8 @@ class TFModelTesterMixin:
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
             inputs_dict = {
-                k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices, 1))
-                if isinstance(v, tf.Tensor) and v.ndim != 0
+                k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
+                if isinstance(v, tf.Tensor) and v.ndim > 0
                 else v
                 for k, v in inputs_dict.items()
             }
tests/test_modeling_tf_flaubert.py

@@ -18,11 +18,340 @@ import unittest
 from transformers import is_tf_available
 from transformers.testing_utils import require_tf, slow

+from .test_configuration_common import ConfigTester
+from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+

 if is_tf_available():
     import tensorflow as tf
     import numpy as np

-    from transformers import TFFlaubertModel
+    from transformers import (
+        FlaubertConfig,
+        TFFlaubertModel,
+        TFFlaubertWithLMHeadModel,
+        TFFlaubertForSequenceClassification,
+        TFFlaubertForQuestionAnsweringSimple,
+        TFFlaubertForTokenClassification,
+        TFFlaubertForMultipleChoice,
+        TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+    )
+
+
+class TFFlaubertModelTester:
+    def __init__(
+        self, parent,
+    ):
+        self.parent = parent
+        self.batch_size = 13
+        self.seq_length = 7
+        self.is_training = True
+        self.use_input_lengths = True
+        self.use_token_type_ids = True
+        self.use_labels = True
+        self.gelu_activation = True
+        self.sinusoidal_embeddings = False
+        self.causal = False
+        self.asm = False
+        self.n_langs = 2
+        self.vocab_size = 99
+        self.n_special = 0
+        self.hidden_size = 32
+        self.num_hidden_layers = 5
+        self.num_attention_heads = 4
+        self.hidden_dropout_prob = 0.1
+        self.attention_probs_dropout_prob = 0.1
+        self.max_position_embeddings = 512
+        self.type_vocab_size = 16
+        self.type_sequence_label_size = 2
+        self.initializer_range = 0.02
+        self.num_labels = 3
+        self.num_choices = 4
+        self.summary_type = "last"
+        self.use_proj = True
+        self.scope = None
+        self.bos_token_id = 0
+
+    def prepare_config_and_inputs(self):
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+        input_mask = ids_tensor([self.batch_size, self.seq_length], 2, dtype=tf.float32)
+
+        input_lengths = None
+        if self.use_input_lengths:
+            input_lengths = (
+                ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2
+            )  # small variation of seq_length
+
+        token_type_ids = None
+        if self.use_token_type_ids:
+            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.n_langs)
+
+        sequence_labels = None
+        token_labels = None
+        is_impossible_labels = None
+        if self.use_labels:
+            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+            is_impossible_labels = ids_tensor([self.batch_size], 2, dtype=tf.float32)
+            choice_labels = ids_tensor([self.batch_size], self.num_choices)
+
+        config = FlaubertConfig(
+            vocab_size=self.vocab_size,
+            n_special=self.n_special,
+            emb_dim=self.hidden_size,
+            n_layers=self.num_hidden_layers,
+            n_heads=self.num_attention_heads,
+            dropout=self.hidden_dropout_prob,
+            attention_dropout=self.attention_probs_dropout_prob,
+            gelu_activation=self.gelu_activation,
+            sinusoidal_embeddings=self.sinusoidal_embeddings,
+            asm=self.asm,
+            causal=self.causal,
+            n_langs=self.n_langs,
+            max_position_embeddings=self.max_position_embeddings,
+            initializer_range=self.initializer_range,
+            summary_type=self.summary_type,
+            use_proj=self.use_proj,
+            bos_token_id=self.bos_token_id,
+        )
+
+        return (
+            config,
+            input_ids,
+            token_type_ids,
+            input_lengths,
+            sequence_labels,
+            token_labels,
+            is_impossible_labels,
+            choice_labels,
+            input_mask,
+        )
+
+    def create_and_check_flaubert_model(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_lengths,
+        sequence_labels,
+        token_labels,
+        is_impossible_labels,
+        choice_labels,
+        input_mask,
+    ):
+        model = TFFlaubertModel(config=config)
+        inputs = {"input_ids": input_ids, "lengths": input_lengths, "langs": token_type_ids}
+        outputs = model(inputs)
+
+        inputs = [input_ids, input_mask]
+        outputs = model(inputs)
+        sequence_output = outputs[0]
+        result = {
+            "sequence_output": sequence_output.numpy(),
+        }
+        self.parent.assertListEqual(
+            list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
+        )
+
+    def create_and_check_flaubert_lm_head(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_lengths,
+        sequence_labels,
+        token_labels,
+        is_impossible_labels,
+        choice_labels,
+        input_mask,
+    ):
+        model = TFFlaubertWithLMHeadModel(config)
+        inputs = {"input_ids": input_ids, "lengths": input_lengths, "langs": token_type_ids}
+        outputs = model(inputs)
+        logits = outputs[0]
+        result = {
+            "logits": logits.numpy(),
+        }
+        self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
+
+    def create_and_check_flaubert_qa(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_lengths,
+        sequence_labels,
+        token_labels,
+        is_impossible_labels,
+        choice_labels,
+        input_mask,
+    ):
+        model = TFFlaubertForQuestionAnsweringSimple(config)
+        inputs = {"input_ids": input_ids, "lengths": input_lengths}
+        start_logits, end_logits = model(inputs)
+        result = {
+            "start_logits": start_logits.numpy(),
+            "end_logits": end_logits.numpy(),
+        }
+        self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
+        self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
+
+    def create_and_check_flaubert_sequence_classif(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_lengths,
+        sequence_labels,
+        token_labels,
+        is_impossible_labels,
+        choice_labels,
+        input_mask,
+    ):
+        model = TFFlaubertForSequenceClassification(config)
+        inputs = {"input_ids": input_ids, "lengths": input_lengths}
+        (logits,) = model(inputs)
+        result = {
+            "logits": logits.numpy(),
+        }
+        self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.type_sequence_label_size])
+
+    def create_and_check_flaubert_for_token_classification(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_lengths,
+        sequence_labels,
+        token_labels,
+        is_impossible_labels,
+        choice_labels,
+        input_mask,
+    ):
+        config.num_labels = self.num_labels
+        model = TFFlaubertForTokenClassification(config=config)
+        inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
+        (logits,) = model(inputs)
+        result = {
+            "logits": logits.numpy(),
+        }
+        self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
+
+    def create_and_check_flaubert_for_multiple_choice(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_lengths,
+        sequence_labels,
+        token_labels,
+        is_impossible_labels,
+        choice_labels,
+        input_mask,
+    ):
+        config.num_choices = self.num_choices
+        model = TFFlaubertForMultipleChoice(config=config)
+        multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
+        multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
+        multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
+        inputs = {
+            "input_ids": multiple_choice_inputs_ids,
+            "attention_mask": multiple_choice_input_mask,
+            "token_type_ids": multiple_choice_token_type_ids,
+        }
+        (logits,) = model(inputs)
+        result = {"logits": logits.numpy()}
+        self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        (
+            config,
+            input_ids,
+            token_type_ids,
+            input_lengths,
+            sequence_labels,
+            token_labels,
+            is_impossible_labels,
+            choice_labels,
+            input_mask,
+        ) = config_and_inputs
+        inputs_dict = {
+            "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
+            "langs": token_type_ids,
+            "lengths": input_lengths,
+        }
+        return config, inputs_dict
+
+
+@require_tf
+class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
+    all_model_classes = (
+        (
+            TFFlaubertModel,
+            TFFlaubertWithLMHeadModel,
+            TFFlaubertForSequenceClassification,
+            TFFlaubertForQuestionAnsweringSimple,
+            TFFlaubertForTokenClassification,
+            TFFlaubertForMultipleChoice,
+        )
+        if is_tf_available()
+        else ()
+    )
+    all_generative_model_classes = (
+        (TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
+    )  # TODO (PVP): Check other models whether language generation is also applicable
+
+    def setUp(self):
+        self.model_tester = TFFlaubertModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=FlaubertConfig, emb_dim=37)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_flaubert_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_flaubert_model(*config_and_inputs)
+
+    def test_flaubert_lm_head(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_flaubert_lm_head(*config_and_inputs)
+
+    def test_flaubert_qa(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_flaubert_qa(*config_and_inputs)
+
+    def test_flaubert_sequence_classif(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_flaubert_sequence_classif(*config_and_inputs)
+
+    def test_for_token_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_flaubert_for_token_classification(*config_and_inputs)
+
+    def test_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_flaubert_for_multiple_choice(*config_and_inputs)
+
+    @slow
+    def test_model_from_pretrained(self):
+        for model_name in TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+            model = TFFlaubertModel.from_pretrained(model_name)
+            self.assertIsNotNone(model)
+

 @require_tf
tests/test_modeling_tf_xlm.py

@@ -32,6 +32,7 @@ if is_tf_available():
         TFXLMForSequenceClassification,
         TFXLMForQuestionAnsweringSimple,
         TFXLMForTokenClassification,
+        TFXLMForMultipleChoice,
         TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
@@ -91,6 +92,7 @@ class TFXLMModelTester:
             sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
             token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
             is_impossible_labels = ids_tensor([self.batch_size], 2, dtype=tf.float32)
+            choice_labels = ids_tensor([self.batch_size], self.num_choices)

         config = XLMConfig(
             vocab_size=self.vocab_size,
@@ -120,6 +122,7 @@ class TFXLMModelTester:
             sequence_labels,
             token_labels,
             is_impossible_labels,
+            choice_labels,
             input_mask,
         )
@@ -132,6 +135,7 @@ class TFXLMModelTester:
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = TFXLMModel(config=config)
@@ -157,6 +161,7 @@ class TFXLMModelTester:
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = TFXLMWithLMHeadModel(config)
@@ -181,6 +186,7 @@ class TFXLMModelTester:
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = TFXLMForQuestionAnsweringSimple(config)
@@ -206,6 +212,7 @@ class TFXLMModelTester:
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = TFXLMForSequenceClassification(config)
@@ -229,6 +236,7 @@ class TFXLMModelTester:
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         config.num_labels = self.num_labels
@@ -240,6 +248,32 @@ class TFXLMModelTester:
         }
         self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])

+    def create_and_check_xlm_for_multiple_choice(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_lengths,
+        sequence_labels,
+        token_labels,
+        is_impossible_labels,
+        choice_labels,
+        input_mask,
+    ):
+        config.num_choices = self.num_choices
+        model = TFXLMForMultipleChoice(config=config)
+        multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
+        multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
+        multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
+        inputs = {
+            "input_ids": multiple_choice_inputs_ids,
+            "attention_mask": multiple_choice_input_mask,
+            "token_type_ids": multiple_choice_token_type_ids,
+        }
+        (logits,) = model(inputs)
+        result = {"logits": logits.numpy()}
+        self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -250,6 +284,7 @@ class TFXLMModelTester:
             sequence_labels,
             token_labels,
             is_impossible_labels,
+            choice_labels,
             input_mask,
         ) = config_and_inputs
         inputs_dict = {
@@ -265,13 +300,13 @@ class TFXLMModelTester:
 class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
     all_model_classes = (
-        # TODO The multiple choice model is missing and should be added.
         (
             TFXLMModel,
             TFXLMWithLMHeadModel,
             TFXLMForSequenceClassification,
             TFXLMForQuestionAnsweringSimple,
             TFXLMForTokenClassification,
+            TFXLMForMultipleChoice,
         )
         if is_tf_available()
         else ()
@@ -307,6 +342,10 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_xlm_for_token_classification(*config_and_inputs)

+    def test_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
tests/test_modeling_xlm.py

@@ -33,6 +33,7 @@ if is_torch_available():
         XLMForQuestionAnswering,
         XLMForSequenceClassification,
         XLMForQuestionAnsweringSimple,
+        XLMForMultipleChoice,
     )
     from transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -63,7 +64,7 @@ class XLMModelTester:
         self.max_position_embeddings = 512
         self.type_sequence_label_size = 2
         self.initializer_range = 0.02
-        self.num_labels = 3
+        self.num_labels = 2
         self.num_choices = 4
         self.summary_type = "last"
         self.use_proj = True
@@ -91,6 +92,7 @@ class XLMModelTester:
             sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
             token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
             is_impossible_labels = ids_tensor([self.batch_size], 2).float()
+            choice_labels = ids_tensor([self.batch_size], self.num_choices)

         config = XLMConfig(
             vocab_size=self.vocab_size,
@@ -109,6 +111,7 @@ class XLMModelTester:
             initializer_range=self.initializer_range,
             summary_type=self.summary_type,
             use_proj=self.use_proj,
+            num_labels=self.num_labels,
             bos_token_id=self.bos_token_id,
         )
@@ -120,6 +123,7 @@ class XLMModelTester:
             sequence_labels,
             token_labels,
             is_impossible_labels,
+            choice_labels,
             input_mask,
         )
@@ -135,6 +139,7 @@ class XLMModelTester:
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = XLMModel(config=config)
@@ -160,6 +165,7 @@ class XLMModelTester:
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = XLMWithLMHeadModel(config)
@@ -185,6 +191,7 @@ class XLMModelTester:
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = XLMForQuestionAnsweringSimple(config)
@@ -214,6 +221,7 @@ class XLMModelTester:
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = XLMForQuestionAnswering(config)
@@ -280,6 +288,7 @@ class XLMModelTester:
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         model = XLMForSequenceClassification(config)
@@ -306,6 +315,7 @@ class XLMModelTester:
         sequence_labels,
         token_labels,
         is_impossible_labels,
+        choice_labels,
         input_mask,
     ):
         config.num_labels = self.num_labels
@@ -321,6 +331,38 @@ class XLMModelTester:
         self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
         self.check_loss_output(result)

+    def create_and_check_xlm_for_multiple_choice(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_lengths,
+        sequence_labels,
+        token_labels,
+        is_impossible_labels,
+        choice_labels,
+        input_mask,
+    ):
+        config.num_choices = self.num_choices
+        model = XLMForMultipleChoice(config=config)
+        model.to(torch_device)
+        model.eval()
+        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        loss, logits = model(
+            multiple_choice_inputs_ids,
+            attention_mask=multiple_choice_input_mask,
+            token_type_ids=multiple_choice_token_type_ids,
+            labels=choice_labels,
+        )
+        result = {
+            "loss": loss,
+            "logits": logits,
+        }
+        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
+        self.check_loss_output(result)
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -331,6 +373,7 @@ class XLMModelTester:
             sequence_labels,
             token_labels,
             is_impossible_labels,
+            choice_labels,
             input_mask,
         ) = config_and_inputs
         inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths}
@@ -348,6 +391,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
             XLMForSequenceClassification,
             XLMForQuestionAnsweringSimple,
             XLMForTokenClassification,
+            XLMForMultipleChoice,
         )
         if is_torch_available()
         else ()
@@ -387,6 +431,10 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_xlm_token_classif(*config_and_inputs)

+    def test_xlm_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: