"examples/trials/vscode:/vscode.git/clone" did not exist on "664186759af6d507c98d1f95986684813a1d5b36"
Unverified commit 82984215, authored by Sam Shleifer, committed by GitHub

Add TFBartForConditionalGeneration (#5411)



* half done

* doc improvement

* Cp test file

* broken

* broken test

* undo some mess

* ckpt

* borked

* Halfway

* 6 passing

* boom boom

* Much progress but still 6

* boom boom

* merged master

* 10 passing

* boom boom

* Style

* no t5 changes

* 13 passing

* Integration test failing, but not gibberish

* Frustrated

* Merged master

* 4 fail

* 4 fail

* fix return_dict

* boom boom

* Still only 4

* prepare method

* prepare method

* before delete classif

* Skip tests to avoid adding boilerplate

* boom boom

* fast tests passing

* style

* boom boom

* Switch to supporting many input types

* remove FIXMENORM

* working

* Fixed past_key_values/decoder_cached_states confusion

* new broken test

* Fix attention mask kwarg name

* undo accidental

* Style and reviewers

* style

* Docs and common tests

* Cleaner assert messages

* copy docs

* style issues

* Sphinx fix

* Simplify caching logic

* test does not require torch

* copy _NoLayerEmbedTokens

* Update src/transformers/modeling_tf_bart.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update tests/test_modeling_tf_bart.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_bart.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_bart.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_bart.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Line length and dont document None

* Add pipeline test coverage

* assert msg

* At parity

* Assert messages

* mark slow

* Update compile test

* back in init

* Merge master

* Fix tests
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 5cd9e2cb
......@@ -86,3 +86,18 @@ BartForQuestionAnswering
.. autoclass:: transformers.BartForQuestionAnswering
:members: forward
TFBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFBartModel
:members: call
TFBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFBartForConditionalGeneration
:members: call
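A minimal usage sketch of the new class (not part of the diff; the checkpoint name and the from_pt conversion are assumptions):

from transformers import BartTokenizer, TFBartForConditionalGeneration

# Load tokenizer and model; from_pt=True converts the PyTorch checkpoint on the fly.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn", from_pt=True)

article = "The tower is 324 metres tall, about the same height as an 81-storey building."
inputs = tokenizer([article], return_tensors="tf")
summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=32, early_stopping=True)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))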
......@@ -652,6 +652,7 @@ if is_tf_available():
TFAutoModelForTokenClassification,
TFAutoModelWithLMHead,
)
from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
from .modeling_tf_bert import (
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFBertEmbeddings,
......
......@@ -20,6 +20,7 @@ import os
from transformers import (
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
......@@ -37,6 +38,7 @@ from transformers import (
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
AlbertConfig,
BartConfig,
BertConfig,
CamembertConfig,
CTRLConfig,
......@@ -49,6 +51,7 @@ from transformers import (
RobertaConfig,
T5Config,
TFAlbertForPreTraining,
TFBartForConditionalGeneration,
TFBertForPreTraining,
TFBertForQuestionAnswering,
TFBertForSequenceClassification,
......@@ -87,6 +90,7 @@ if is_torch_available():
from transformers import (
AlbertForPreTraining,
BartForConditionalGeneration,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
......@@ -113,6 +117,12 @@ if is_torch_available():
logging.set_verbosity_info()
MODEL_CLASSES = {
"bart": (
BartConfig,
TFBartForConditionalGeneration,
BartForConditionalGeneration,
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
),
"bert": (
BertConfig,
TFBertForPreTraining,
......
......@@ -640,6 +640,10 @@ class TFGenerationMixin:
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
if self.config.is_encoder_decoder and do_sample is False:
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
# calculate log softmax score
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
......@@ -890,6 +894,13 @@ class TFGenerationMixin:
def _reorder_cache(past, beam_idx):
return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
def adjust_logits_during_generation(self, logits, **kwargs):
"""
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
the generate method.
"""
return logits
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
# create logit penalties for already seen input_ids
......
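The adjust_logits_during_generation hook added here is a no-op by default; seq2seq models override it to pin the first or last generated token (see the Bart override later in this diff). A rough TF sketch of that masking idea, not taken from the commit itself:

import tensorflow as tf

def force_token_id(logits, token_id):
    # Keep the logit of token_id and push every other logit to -inf, so the
    # subsequent sampling / beam-search step can only select token_id.
    vocab_size = logits.shape[-1]
    mask = tf.one_hot(token_id, vocab_size, on_value=0.0, off_value=-float("inf"))
    return logits + mask

# e.g. at the first decoding step an override could do:
#   logits = force_token_id(next_token_logits, self.config.bos_token_id)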
......@@ -131,7 +131,7 @@ BART_INPUTS_DOCSTRING = r"""
:obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`) is a
sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
the decoder.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
past_key_values (:obj:`Tuple[Dict[str: tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last
......@@ -217,12 +217,6 @@ def _make_linear_from_emb(emb):
return lin_layer
# Helper Functions, mostly for making masks
def _check_shapes(shape_1, shape2):
if shape_1 != shape2:
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
def shift_tokens_right(input_ids, pad_token_id):
"""Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
prev_output_tokens = input_ids.clone()
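The hunk above truncates the body of shift_tokens_right; the following is a standalone sketch of the behaviour its docstring describes (shift right, wrap the last non-pad token to position 0), written independently of the commit:

import torch

def shift_tokens_right_sketch(input_ids, pad_token_id):
    # Shift every token one position to the right and wrap the last non-pad
    # token (usually </s>) around to the first position.
    prev_output_tokens = input_ids.clone()
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze(-1)
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens

# shift_tokens_right_sketch(torch.tensor([[5, 6, 7, 2, 1, 1]]), pad_token_id=1)
# -> tensor([[2, 5, 6, 7, 2, 1]])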
......@@ -595,7 +589,7 @@ class BartDecoder(nn.Module):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = []
next_decoder_cache: List[Dict] = []
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
......@@ -640,7 +634,7 @@ class BartDecoder(nn.Module):
)
def _reorder_buffer(attn_cache, new_order):
def _reorder_buffer(attn_cache: Dict, new_order) -> Dict:
for k, input_buffer_k in attn_cache.items():
if input_buffer_k is not None:
attn_cache[k] = input_buffer_k.index_select(0, new_order)
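_reorder_buffer, retyped above, reorders one attention module's cached tensors for beam search. Conceptually (a standalone sketch, not the diff's code):

import torch

def reorder_layer_cache(attn_cache, new_order):
    # Each cached tensor is laid out with the (batch * beams) dimension first,
    # so following the new beam order is a single index_select along dim 0.
    for k, buf in attn_cache.items():
        if buf is not None:
            attn_cache[k] = buf.index_select(0, new_order)
    return attn_cache

cache = {"prev_key": torch.randn(4, 16, 7, 64), "prev_value": torch.randn(4, 16, 7, 64)}
reordered = reorder_layer_cache(cache, torch.tensor([2, 0, 1, 3]))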
......@@ -679,17 +673,15 @@ class Attention(nn.Module):
def forward(
self,
query,
key: Optional[Tensor],
key: Tensor,
key_padding_mask: Optional[Tensor] = None,
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
layer_state: Optional[Dict[str, Tensor]] = None,
attn_mask: Optional[Tensor] = None,
output_attentions=False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time(SeqLen) x Batch x Channel"""
static_kv: bool = self.encoder_decoder_attention
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
# get here for encoder decoder cause of static_kv
if layer_state is not None: # reuse k,v and encoder_padding_mask
saved_state = layer_state.get(self.cache_key, {})
......@@ -697,17 +689,16 @@ class Attention(nn.Module):
# previous time steps are cached - no need to recompute key and value if they are static
key = None
else:
# this branch is hit by encoder
saved_state = None
layer_state = {}
q = self.q_proj(query) * self.scaling
if static_kv:
if key is None:
if static_kv and key is None: # cross-attention with cache
k = v = None
else:
elif static_kv and key is not None: # cross-attention no prev_key found in cache
k = self.k_proj(key)
v = self.v_proj(key)
else:
else: # self-attention
k = self.k_proj(query)
v = self.v_proj(query)
......@@ -717,18 +708,16 @@ class Attention(nn.Module):
if v is not None:
v = self._shape(v, -1, bsz)
if saved_state is not None:
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
if saved_state:
k, v = self._concat_saved_state(k, v, saved_state, static_kv, bsz)
# Update cache
layer_state[self.cache_key] = {
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
}
if isinstance(layer_state, dict):
cached_shape = (bsz, self.num_heads, -1, self.head_dim) # bsz must be first for reorder_cache
layer_state[self.cache_key] = dict(prev_key=k.view(*cached_shape), prev_value=v.view(*cached_shape))
assert k is not None
src_len = k.size(1)
assert key_padding_mask is None or key_padding_mask.shape == (bsz, src_len)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
......@@ -736,13 +725,7 @@ class Attention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
assert key_padding_mask is None or key_padding_mask.size()[:2] == (
bsz,
src_len,
)
# Note: deleted workaround to get around fork/join parallelism not supporting Optional types. on 2020/10/15
if key_padding_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
......@@ -750,11 +733,7 @@ class Attention(nn.Module):
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_probs = F.dropout(
attn_weights,
p=self.dropout,
training=self.training,
)
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
assert v is not None
attn_output = torch.bmm(attn_probs, v)
......@@ -767,36 +746,13 @@ class Attention(nn.Module):
attn_weights = None
return attn_output, attn_weights
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
def _concat_saved_state(self, k, v, saved_state, static_kv, bsz) -> Tuple[Tensor]:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
assert k is not None and v is not None
prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
if prev_key_padding_mask is not None:
if static_kv:
new_key_padding_mask = prev_key_padding_mask
else:
new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
else:
new_key_padding_mask = key_padding_mask
return k, v, new_key_padding_mask
prev_K = saved_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
prev_V = saved_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
new_K = prev_K if static_kv else torch.cat([prev_K, k], dim=1)
new_V = prev_V if static_kv else torch.cat([prev_V, v], dim=1)
return new_K, new_V
class BartClassificationHead(nn.Module):
......@@ -1143,14 +1099,15 @@ class BartForConditionalGeneration(PretrainedBartModel):
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
self._force_token_ids_generation(logits, self.config.bos_token_id)
self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(logits, self.config.eos_token_id)
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
return logits
def _force_token_ids_generation(self, scores, token_id) -> None:
@staticmethod
def _force_token_id_to_be_generated(scores, token_id) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
@staticmethod
def _reorder_cache(past, beam_idx):
......
......@@ -52,5 +52,5 @@ class BlenderbotForConditionalGeneration(BartForConditionalGeneration):
def adjust_logits_during_generation(self, logits, cur_len, max_length):
logits[:, self.config.bos_token_id] = -torch.finfo(torch.float16).max # near infinity fp16
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(logits, self.config.eos_token_id)
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
return logits
......@@ -51,5 +51,5 @@ class MarianMTModel(BartForConditionalGeneration):
def adjust_logits_during_generation(self, logits, cur_len, max_length):
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(logits, self.config.eos_token_id)
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
return logits
......@@ -21,6 +21,7 @@ from collections import OrderedDict
from .configuration_auto import (
AlbertConfig,
AutoConfig,
BartConfig,
BertConfig,
CamembertConfig,
CTRLConfig,
......@@ -51,6 +52,7 @@ from .modeling_tf_albert import (
TFAlbertForTokenClassification,
TFAlbertModel,
)
from .modeling_tf_bart import TFBartForConditionalGeneration
from .modeling_tf_bert import (
TFBertForMaskedLM,
TFBertForMultipleChoice,
......@@ -206,6 +208,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForMaskedLM),
(BartConfig, TFBartForConditionalGeneration),
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(LongformerConfig, TFLongformerForMaskedLM),
......@@ -256,7 +259,9 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
]
)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict([(T5Config, TFT5ForConditionalGeneration)])
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[(T5Config, TFT5ForConditionalGeneration), (BartConfig, TFBartForConditionalGeneration)]
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
......
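With BartConfig registered in the TF LM-head and seq-to-seq mappings above, the TF auto classes can now resolve a BART checkpoint to the new class. A quick sketch (checkpoint name illustrative; loading from PyTorch weights assumed):

from transformers import TFAutoModelWithLMHead

model = TFAutoModelWithLMHead.from_pretrained("facebook/bart-large", from_pt=True)
print(type(model).__name__)  # expected: TFBartForConditionalGeneration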
This diff is collapsed.
......@@ -717,7 +717,7 @@ BERT_START_DOCSTRING = r"""
Args:
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the model weights.
"""
BERT_INPUTS_DOCSTRING = r"""
......
......@@ -229,7 +229,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
else:
logger.warning(
f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
f"If your task is similar to the task the model of the checkpoint was trained on, "
f"you can already use {tf_model.__class__.__name__} for predictions without further training."
)
......@@ -383,7 +383,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
else:
logger.warning(
f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
f"If your task is similar to the task the model of the checkpoint was trained on, "
f"you can already use {pt_model.__class__.__name__} for predictions without further training."
)
......
......@@ -20,6 +20,7 @@ import copy
import itertools
import math
import warnings
from typing import Tuple
import tensorflow as tf
......@@ -594,7 +595,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
training=False,
**kwargs,
):
) -> Tuple:
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
......@@ -699,7 +700,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# positions we want to attend and -1e9 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
......@@ -721,7 +722,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
if num_dims_encoder_attention_mask == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposistion
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
......@@ -1417,7 +1418,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
"use_cache": use_cache,
}
def _reorder_cache(self, past, beam_idx):
def _reorder_cache(self, past, beam_idx) -> Tuple:
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
......
......@@ -136,8 +136,7 @@ class TFCausalLanguageModelingLoss:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
# make sure only labels that are not equal to -100
# are taken into account as loss
# make sure that only labels not equal to -100 affect the loss
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
......
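For reference (not part of the diff), the -100 masking shown above amounts to dropping ignored positions before computing the sparse cross-entropy:

import tensorflow as tf

labels = tf.constant([5, -100, 2])              # -100 marks positions to ignore
logits = tf.random.normal((3, 10))
active_loss = tf.not_equal(labels, -100)
reduced_logits = tf.boolean_mask(logits, active_loss)   # shape (2, 10)
masked_labels = tf.boolean_mask(labels, active_loss)    # [5, 2]
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
loss = loss_fn(masked_labels, reduced_logits)            # per-token losses, shape (2,)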
......@@ -1945,11 +1945,6 @@ class SummarizationPipeline(Pipeline):
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
assert len(documents) > 0, "Please provide a document to summarize"
if self.framework == "tf" and "BartForConditionalGeneration" in self.model.__class__.__name__:
raise NotImplementedError(
"Tensorflow is not yet supported for Bart. Please consider using T5, e.g. `t5-base`"
)
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
if isinstance(documents[0], list):
......
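With the guard above removed, the summarization pipeline can run BART under TensorFlow. A sketch using the tiny checkpoint this commit adds to the pipeline tests, assuming that checkpoint provides TF-loadable weights:

from transformers import pipeline

summarizer = pipeline("summarization", model="sshleifer/bart-tiny-random", framework="tf")
print(summarizer("New York City is the most populous city in the United States.", max_length=10))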
......@@ -212,6 +212,24 @@ class TFAutoModelWithLMHead:
requires_tf(self)
class TFBartForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFBartModel:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
This diff is collapsed.
This diff is collapsed.
......@@ -302,7 +302,7 @@ class TFModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beggining
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
config.output_hidden_states = True
......@@ -472,10 +472,9 @@ class TFModelTesterMixin:
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
outputs = model(self._prepare_for_class(inputs_dict, model_class)) # build the model
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
......@@ -494,7 +493,9 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes:
model = model_class(config)
outputs_dict = model(self._prepare_for_class(inputs_dict, model_class))
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs_dict = model(inputs)
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
input_ids = inputs_keywords.pop("input_ids", None)
......@@ -507,28 +508,18 @@ class TFModelTesterMixin:
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
decoder_seq_length = (
self.model_tester.decoder_seq_length
if hasattr(self.model_tester, "decoder_seq_length")
else self.model_tester.seq_length
)
encoder_seq_length = (
self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "encoder_seq_length")
else self.model_tester.seq_length
)
decoder_key_length = (
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else decoder_seq_length
)
encoder_key_length = (
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", self.model_tester.seq_length)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["use_cache"] = False
config.output_hidden_states = False
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
model_inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(model_inputs)
attentions = [t.numpy() for t in outputs[-1]]
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
......
......@@ -279,8 +279,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
@slow
def test_model_from_pretrained(self):
for model_name in ["t5-small"]:
model = TFT5Model.from_pretrained(model_name)
model = TFT5Model.from_pretrained("t5-small")
self.assertIsNotNone(model)
......
......@@ -21,7 +21,7 @@ FILL_MASK_FINETUNED_MODELS = ["sshleifer/tiny-distilroberta-base"]
LARGE_FILL_MASK_FINETUNED_MODELS = ["distilroberta-base"] # @slow
SUMMARIZATION_FINETUNED_MODELS = ["sshleifer/bart-tiny-random", "patrickvonplaten/t5-tiny-random"]
TF_SUMMARIZATION_FINETUNED_MODELS = ["patrickvonplaten/t5-tiny-random"]
TF_SUMMARIZATION_FINETUNED_MODELS = ["sshleifer/bart-tiny-random", "patrickvonplaten/t5-tiny-random"]
TRANSLATION_FINETUNED_MODELS = [
("patrickvonplaten/t5-tiny-random", "translation_en_to_de"),
......