"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c8b6052ff681e3ca8dab168dfd524b9fbbceb5bd"
Unverified commit fc95386e, authored by Cole Howard, committed by GitHub

Add TFBartForSequenceClassification (#20570)

* read to load

* base functionality

* revert init

* fix dummy data

* moving right along

* moving right along

* finally

* cleanup

* pull out comment

* add test

* update docstring for main class

* flake comments and rewriting copies from make repo-consistency`

* remove irrelevant differences/accidental spaces

* put copies back after space removals

* mid

* final test pass

* stray comment

* update test file

* update test file

* fixup

* black

* missed

* black missed one more

* sytle

* add doc update

* fix order of output class

* comment

* Revert "comment"

This reverts commit 03f86b6948808461939cc8ad4ad74305dfb67700.

* remove redundant function, and redundant reshape

* move change out of common

* style

* put common spaces back

* reorder kwargs in output

* doc style
parent 77382e91
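In short: this PR ports BART's sequence-classification head to TensorFlow. The new `TFBartForSequenceClassification` wraps `TFBartMainLayer`, pools the decoder's last hidden state at the final `eos` token of each example, and feeds it through a small dense/tanh/dense head, mirroring the existing PyTorch `BartForSequenceClassification`. A minimal usage sketch (illustrative only; `facebook/bart-base` ships no fine-tuned classification head, so the head weights below are freshly initialized):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFBartForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
# num_labels sizes the freshly initialized classification head
model = TFBartForSequenceClassification.from_pretrained("facebook/bart-base", num_labels=2)

batch = tokenizer(["I loved it.", "I did not."], padding=True, return_tensors="tf")
outputs = model(**batch, labels=tf.constant([1, 0]))
print(outputs.logits.shape)  # (batch_size, num_labels) -> (2, 2)
```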
@@ -157,6 +157,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with BART.
 [[autodoc]] TFBartForConditionalGeneration
     - call
 
+## TFBartForSequenceClassification
+
+[[autodoc]] TFBartForSequenceClassification
+    - call
+
 ## FlaxBartModel
 
 [[autodoc]] FlaxBartModel
...
@@ -2513,7 +2513,9 @@ else:
             "TFAutoModelWithLMHead",
         ]
     )
-    _import_structure["models.bart"].extend(["TFBartForConditionalGeneration", "TFBartModel", "TFBartPretrainedModel"])
+    _import_structure["models.bart"].extend(
+        ["TFBartForConditionalGeneration", "TFBartForSequenceClassification", "TFBartModel", "TFBartPretrainedModel"]
+    )
     _import_structure["models.bert"].extend(
         [
             "TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -5402,7 +5404,12 @@ if TYPE_CHECKING:
             TFAutoModelForVision2Seq,
             TFAutoModelWithLMHead,
         )
-        from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
+        from .models.bart import (
+            TFBartForConditionalGeneration,
+            TFBartForSequenceClassification,
+            TFBartModel,
+            TFBartPretrainedModel,
+        )
        from .models.bert import (
             TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
             TFBertEmbeddings,
...
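For context (not part of the diff): `_import_structure` feeds transformers' lazy-module machinery, so listing the new class here is what makes `from transformers import TFBartForSequenceClassification` resolve without importing the TensorFlow code eagerly. A reduced sketch of the pattern, with a simplified stand-in for the real `_LazyModule` in `transformers.utils`:

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Simplified stand-in: resolves exported names to their submodules on first access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Invert {submodule: [symbols]} into {symbol: submodule}
        self._symbol_to_submodule = {
            symbol: submodule for submodule, symbols in import_structure.items() for symbol in symbols
        }

    def __getattr__(self, symbol):
        # Only now is the submodule actually imported
        submodule = self._symbol_to_submodule[symbol]
        module = importlib.import_module(f"{self.__name__}.{submodule}")
        return getattr(module, symbol)
```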
@@ -58,6 +58,7 @@ from . import (
     T5Config,
     TFAlbertForPreTraining,
     TFBartForConditionalGeneration,
+    TFBartForSequenceClassification,
     TFBertForPreTraining,
     TFBertForQuestionAnswering,
     TFBertForSequenceClassification,
@@ -136,6 +137,7 @@ MODEL_CLASSES = {
     "bart": (
         BartConfig,
         TFBartForConditionalGeneration,
+        TFBartForSequenceClassification,
         BartForConditionalGeneration,
         BART_PRETRAINED_MODEL_ARCHIVE_LIST,
     ),
...
@@ -623,6 +623,9 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
             Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
             self-attention heads.
+        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
         encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder of the model.
         encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
@@ -643,6 +646,7 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
     past_key_values: Optional[List[tf.Tensor]] = None
     decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
     decoder_attentions: Optional[Tuple[tf.Tensor]] = None
+    cross_attentions: Optional[Tuple[tf.Tensor]] = None
     encoder_last_hidden_state: Optional[tf.Tensor] = None
     encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
     encoder_attentions: Optional[Tuple[tf.Tensor]] = None
...
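A quick way to see the new field in practice (a sketch; any BART checkpoint works, and the head here is untrained):

```python
from transformers import AutoTokenizer, TFBartForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = TFBartForSequenceClassification.from_pretrained("facebook/bart-base", num_labels=2)
inputs = tokenizer("A short example.", return_tensors="tf")

outputs = model(**inputs, output_attentions=True)
# One tensor per decoder layer: (batch_size, num_heads, target_len, source_len)
print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)
```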
@@ -1190,7 +1190,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
         return self.serving_output(output)
 
-    def serving_output(output):
+    def serving_output(self, output):
         """
         Prepare the output of the saved model. Each model must implement this function.
...
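Drive-by fix: the base-class stub was missing `self`, so the `self.serving_output(output)` call just above would have raised a `TypeError` ("takes 1 positional argument but 2 were given") before ever reaching the stub's body.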
@@ -268,6 +268,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Sequence Classification mapping
         ("albert", "TFAlbertForSequenceClassification"),
+        ("bart", "TFBartForSequenceClassification"),
         ("bert", "TFBertForSequenceClassification"),
         ("camembert", "TFCamembertForSequenceClassification"),
         ("convbert", "TFConvBertForSequenceClassification"),
...
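With the mapping entry in place, the auto class dispatches BART configs to the new head. A quick sketch (randomly initialized weights, illustrative only):

```python
from transformers import BartConfig, TFAutoModelForSequenceClassification

config = BartConfig(num_labels=3)
model = TFAutoModelForSequenceClassification.from_config(config)
print(type(model).__name__)  # TFBartForSequenceClassification
```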
@@ -63,7 +63,12 @@ try:
 except OptionalDependencyNotAvailable:
     pass
 else:
-    _import_structure["modeling_tf_bart"] = ["TFBartForConditionalGeneration", "TFBartModel", "TFBartPretrainedModel"]
+    _import_structure["modeling_tf_bart"] = [
+        "TFBartForConditionalGeneration",
+        "TFBartForSequenceClassification",
+        "TFBartModel",
+        "TFBartPretrainedModel",
+    ]
 
 try:
     if not is_flax_available():
@@ -116,7 +121,12 @@ if TYPE_CHECKING:
     except OptionalDependencyNotAvailable:
         pass
     else:
-        from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
+        from .modeling_tf_bart import (
+            TFBartForConditionalGeneration,
+            TFBartForSequenceClassification,
+            TFBartModel,
+            TFBartPretrainedModel,
+        )
 
     try:
         if not is_flax_available():
...
@@ -27,6 +27,7 @@ from ...modeling_tf_outputs import (
     TFBaseModelOutputWithPastAndCrossAttentions,
     TFSeq2SeqLMOutput,
     TFSeq2SeqModelOutput,
+    TFSeq2SeqSequenceClassifierOutput,
 )
 
 # Public API
@@ -35,6 +36,7 @@ from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFModelInputType,
     TFPreTrainedModel,
+    TFSequenceClassificationLoss,
     keras_serializable,
     unpack_inputs,
 )
@@ -460,6 +462,24 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
     )
 
 
+class TFBartClassificationHead(tf.keras.layers.Layer):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, inner_dim: int, num_classes: int, pooler_dropout: float, name: str, **kwargs):
+        super().__init__(name=name, **kwargs)
+        self.dense = tf.keras.layers.Dense(inner_dim, name="dense")
+        self.dropout = tf.keras.layers.Dropout(pooler_dropout)
+        self.out_proj = tf.keras.layers.Dense(num_classes, name="out_proj")
+
+    def call(self, inputs):
+        hidden_states = self.dropout(inputs)
+        hidden_states = self.dense(hidden_states)
+        hidden_states = tf.keras.activations.tanh(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.out_proj(hidden_states)
+        return hidden_states
+
+
 class TFBartPretrainedModel(TFPreTrainedModel):
     config_class = BartConfig
     base_model_prefix = "model"
@@ -726,7 +746,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -1465,3 +1484,141 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
                 tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
             )
         return reordered_past
+
+
+@add_start_docstrings(
+    """
+    Bart model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE
+    tasks.
+    """,
+    BART_START_DOCSTRING,
+)
+class TFBartForSequenceClassification(TFBartPretrainedModel, TFSequenceClassificationLoss):
+    @property
+    def dummy_inputs(self):
+        pad_token = self.config.pad_token_id
+        input_ids = tf.constant([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
+        dummy_inputs = {
+            "attention_mask": tf.cast(tf.math.not_equal(input_ids, (pad_token)), dtype=tf.int32),
+            "input_ids": input_ids,
+        }
+        return dummy_inputs
+
+    def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")
+        self.classification_head = TFBartClassificationHead(
+            config.d_model, config.num_labels, config.classifier_dropout, name="classification_head"
+        )
+
+    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFSeq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    @unpack_inputs
+    def call(
+        self,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_outputs: Optional[TFBaseModelOutput] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[tf.Tensor] = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFSeq2SeqSequenceClassifierOutput, Tuple[tf.Tensor]]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        Returns:
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        if input_ids is None and inputs_embeds is not None:
+            raise NotImplementedError(
+                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
+            )
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            decoder_position_ids=decoder_position_ids,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        last_hidden_state = outputs[0]
+        eos_mask = tf.equal(input_ids, self.config.eos_token_id)
+        # tf.boolean_mask drops the False entries; the reshape back to (batch_size, -1)
+        # only lines up when every example has the same number of eos tokens, and the
+        # assert then verifies that the final token of each example is eos.
+        self_masked = tf.reshape(tf.boolean_mask(eos_mask, eos_mask), (tf.shape(input_ids)[0], -1))
+        tf.Assert(tf.reduce_all(self_masked[:, -1]), ["All examples must have the same number of <eos> tokens."])
+
+        masked = tf.reshape(
+            tf.boolean_mask(last_hidden_state, eos_mask),
+            (tf.shape(input_ids)[0], tf.shape(self_masked)[1], tf.shape(last_hidden_state)[-1]),
+        )
+
+        sentence_representation = masked[:, -1, :]
+        logits = self.classification_head(sentence_representation)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFSeq2SeqSequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+    def serving_output(self, output):
+        logits = tf.convert_to_tensor(output.logits)
+        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
+        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
+        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
+        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
+        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
+
+        return TFSeq2SeqSequenceClassifierOutput(
+            logits=logits,
+            past_key_values=pkv,
+            decoder_hidden_states=dec_hs,
+            decoder_attentions=dec_attns,
+            cross_attentions=cross_attns,
+            encoder_last_hidden_state=output.encoder_last_hidden_state,
+            encoder_hidden_states=enc_hs,
+            encoder_attentions=enc_attns,
+        )
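The eos-pooling block in `call` is the subtle part of this class, so here is a standalone sketch of just that step with concrete shapes (assuming `eos_token_id = 2`; tensors are made up):

```python
import tensorflow as tf

eos_token_id = 2
input_ids = tf.constant([[5, 7, 2, 9, 2],
                         [6, 2, 8, 8, 2]])        # two eos tokens per row, last token is eos
last_hidden_state = tf.random.normal((2, 5, 4))   # (batch_size, seq_len, hidden_size)

eos_mask = tf.equal(input_ids, eos_token_id)
# Flatten to the True positions; the reshape to (batch_size, -1) only lines up when
# every row has the same eos count (uneven counts typically raise InvalidArgumentError).
self_masked = tf.reshape(tf.boolean_mask(eos_mask, eos_mask), (tf.shape(input_ids)[0], -1))
tf.Assert(tf.reduce_all(self_masked[:, -1]), ["All examples must have the same number of <eos> tokens."])

# Gather the hidden states at the eos positions, then keep only the final one per row.
masked = tf.reshape(
    tf.boolean_mask(last_hidden_state, eos_mask),
    (tf.shape(input_ids)[0], tf.shape(self_masked)[1], tf.shape(last_hidden_state)[-1]),
)
sentence_representation = masked[:, -1, :]        # (batch_size, hidden_size) -> fed to the head
print(sentence_representation.shape)              # (2, 4)
```

This matches the PyTorch implementation's convention of classifying from the hidden state of the last `eos` token; BART has no CLS token or pooler.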
@@ -449,6 +449,13 @@ class TFBartForConditionalGeneration(metaclass=DummyObject):
         requires_backends(self, ["tf"])
 
 
+class TFBartForSequenceClassification(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 class TFBartModel(metaclass=DummyObject):
     _backends = ["tf"]
...
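For readers unfamiliar with the dummy-object convention: this placeholder is what gets imported when TensorFlow is absent, so using the class fails with an actionable install hint rather than an import-time crash. A reduced sketch of the mechanism (simplified from the helpers in `transformers.utils`; treat the details as illustrative):

```python
def requires_backends(obj, backends):
    # The real helper checks per-backend availability; here we assume TF is missing.
    name = getattr(obj, "__name__", obj.__class__.__name__)
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")


class DummyObject(type):
    """Metaclass: accessing any public attribute on the class raises the hint above."""

    def __getattr__(cls, key):
        if key.startswith("_"):
            raise AttributeError(key)
        requires_backends(cls, cls._backends)
```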
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
+import tempfile
 import unittest
 
 import numpy as np
@@ -29,7 +31,7 @@ from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin
 if is_tf_available():
     import tensorflow as tf
 
-    from transformers import TFBartForConditionalGeneration, TFBartModel
+    from transformers import TFBartForConditionalGeneration, TFBartForSequenceClassification, TFBartModel
 
 
 @require_tf
@@ -76,7 +78,13 @@ class TFBartModelTester:
         self.bos_token_id = bos_token_id
 
     def prepare_config_and_inputs_for_common(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
+        # Ids are clipped to avoid "beginning of sequence", "end of sequence", and "pad" tokens
+        input_ids = tf.clip_by_value(
+            ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size),
+            clip_value_min=self.eos_token_id + 1,
+            clip_value_max=self.vocab_size + 1,
+        )
+
+        # Explicitly add "end of sequence" to the inputs
         eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
         input_ids = tf.concat([input_ids, eos_tensor], axis=1)
 
@@ -181,7 +189,9 @@ def prepare_bart_inputs_dict(
 @require_tf
 class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
-    all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
+    all_model_classes = (
+        (TFBartForConditionalGeneration, TFBartForSequenceClassification, TFBartModel) if is_tf_available() else ()
+    )
     all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
     is_encoder_decoder = True
     test_pruning = False
@@ -228,6 +238,119 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
     def test_onnx_compliancy(self):
         pass
+
+    # TFBartForSequenceClassification does not support inputs_embeds
+    def test_inputs_embeds(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in (TFBartForConditionalGeneration, TFBartModel):
+            model = model_class(config)
+            inputs = copy.deepcopy(inputs_dict)
+
+            if not self.is_encoder_decoder:
+                input_ids = inputs["input_ids"]
+                del inputs["input_ids"]
+            else:
+                encoder_input_ids = inputs["input_ids"]
+                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+                del inputs["input_ids"]
+                inputs.pop("decoder_input_ids", None)
+
+            if not self.is_encoder_decoder:
+                inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
+            else:
+                inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
+                inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
+
+            inputs = self._prepare_for_class(inputs, model_class)
+
+            model(inputs)
+
+    # TFBartForSequenceClassification does not support inputs_embeds
+    @slow
+    def test_graph_mode_with_inputs_embeds(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in (TFBartForConditionalGeneration, TFBartModel):
+            model = model_class(config)
+
+            inputs = copy.deepcopy(inputs_dict)
+
+            if not self.is_encoder_decoder:
+                input_ids = inputs["input_ids"]
+                del inputs["input_ids"]
+            else:
+                encoder_input_ids = inputs["input_ids"]
+                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+                del inputs["input_ids"]
+                inputs.pop("decoder_input_ids", None)
+
+            if not self.is_encoder_decoder:
+                inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
+            else:
+                inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
+                inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
+
+            inputs = self._prepare_for_class(inputs, model_class)
+
+            @tf.function
+            def run_in_graph_mode():
+                return model(inputs)
+
+            outputs = run_in_graph_mode()
+            self.assertIsNotNone(outputs)
+
+    @slow
+    def test_save_load_after_resize_token_embeddings(self):
+        # Custom version of this test to ensure "end of sequence" tokens are present throughout
+        if not self.test_resize_embeddings:
+            return
+        config, original_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            # create a model with resized (expanded) embeddings
+            new_tokens_size = 10
+            old_total_size = config.vocab_size
+            new_total_size = old_total_size + new_tokens_size
+            model = model_class(config=copy.deepcopy(config))  # `resize_token_embeddings` mutates `config`
+            model(model.dummy_inputs)  # builds the embeddings layer
+            model.resize_token_embeddings(new_total_size)
+
+            # fetch the output for an input exclusively made of new members of the vocabulary
+            inputs_dict = copy.deepcopy(original_inputs_dict)
+            ids_feat_name = None
+            if "input_ids" in inputs_dict:
+                ids_feat_name = "input_ids"
+            elif "decoder_input_ids" in inputs_dict:
+                ids_feat_name = "decoder_input_ids"
+            else:
+                assert False, "No input ids feature found in the inputs dict"
+
+            new_vocab_input_ids = ids_tensor(inputs_dict[ids_feat_name].shape, new_tokens_size)
+            new_vocab_input_ids += old_total_size
+
+            # replace last id with EOS token
+            new_vocab_input_ids = new_vocab_input_ids[:, :-1]
+            new_vocab_input_ids = tf.concat(
+                [new_vocab_input_ids, tf.ones((tf.shape(new_vocab_input_ids)[0], 1), dtype=tf.int32) * 2], axis=1
+            )
+
+            inputs_dict[ids_feat_name] = new_vocab_input_ids
+            if "input_ids" in inputs_dict:
+                inputs_dict["input_ids"] = new_vocab_input_ids
+            if "decoder_input_ids" in inputs_dict:
+                inputs_dict["decoder_input_ids"] = new_vocab_input_ids
+            prepared_inputs = self._prepare_for_class(inputs_dict, model_class)
+            outputs = model(**prepared_inputs)
+
+            # save and load the model
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname, saved_model=False)
+                model = model_class.from_pretrained(tmpdirname)
+                restored_model_outputs = model(**prepared_inputs)
+
+                # check that the output for the restored model is the same
+                self.assert_outputs_same(restored_model_outputs, outputs)
 
 
 def _long_tensor(tok_lst):
     return tf.constant(tok_lst, dtype=tf.int32)
 
@@ -286,6 +409,19 @@ class TFBartHeadTests(unittest.TestCase):
         self.assertEqual(outputs.logits.shape, expected_shape)
 
 
+@require_tf
+class TFBartForSequenceClassificationTest(unittest.TestCase):
+    def test_model_fails_for_uneven_eos_tokens(self):
+        config = BartConfig(eos_token_id=2)
+        model = TFBartForSequenceClassification(config)
+        inputs = {
+            "input_ids": tf.constant([[1, 2, 2, 2], [1, 3, 2, 2], [2, 2, 3, 3]]),
+            "attention_mask": tf.constant([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]),
+        }
+        with self.assertRaises(tf.errors.InvalidArgumentError):
+            model(inputs)
+
+
 @slow
 @require_tf
 class TFBartModelIntegrationTest(unittest.TestCase):
...
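Worth noting for reviewers: in the uneven-eos test above, the three rows contain 3, 2, and 2 `eos` tokens respectively (7 in total), so the `tf.reshape(..., (batch_size, -1))` inside `call` cannot split 7 positions evenly across 3 rows and raises the expected `InvalidArgumentError`.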