Unverified commit 82984215 authored by Sam Shleifer, committed by GitHub

Add TFBartForConditionalGeneration (#5411)



* half done

* doc improvement

* Cp test file

* broken

* broken test

* undo some mess

* ckpt

* borked

* Halfway

* 6 passing

* boom boom

* Much progress but still 6

* boom boom

* merged master

* 10 passing

* boom boom

* Style

* no t5 changes

* 13 passing

* Integration test failing, but not gibberish

* Frustrated

* Merged master

* 4 fail

* 4 fail

* fix return_dict

* boom boom

* Still only 4

* prepare method

* prepare method

* before delete classif

* Skip tests to avoid adding boilerplate

* boom boom

* fast tests passing

* style

* boom boom

* Switch to supporting many input types

* remove FIXMENORM

* working

* Fixed past_key_values/decoder_cached_states confusion

* new broken test

* Fix attention mask kwarg name

* undo accidental

* Style and reviewers

* style

* Docs and common tests

* Cleaner assert messages

* copy docs

* style issues

* Sphinx fix

* Simplify caching logic

* test does not require torch

* copy _NoLayerEmbedTokens

* Update src/transformers/modeling_tf_bart.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update tests/test_modeling_tf_bart.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_bart.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_bart.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_bart.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Line length and dont document None

* Add pipeline test coverage

* assert msg

* At parity

* Assert messages

* mark slow

* Update compile test

* back in init

* Merge master

* Fix tests
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 5cd9e2cb
@@ -86,3 +86,18 @@ BartForQuestionAnswering
 .. autoclass:: transformers.BartForQuestionAnswering
     :members: forward
+
+TFBartModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFBartModel
+    :members: call
+
+TFBartForConditionalGeneration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFBartForConditionalGeneration
+    :members: call
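For illustration, a minimal usage sketch (not part of the diff) of the newly documented classes. The checkpoint name is only an example; if TF weights are not hosted for it, `from_pt=True` converts the PyTorch weights on load.

```python
from transformers import BartTokenizer, TFBartForConditionalGeneration

# Illustrative checkpoint; from_pt=True is only needed if TF weights are not available.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn", from_pt=True)

inputs = tokenizer(["The tower is 324 metres tall, about the same height as an 81-storey building."],
                   return_tensors="tf")
summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=32, early_stopping=True)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```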
@@ -652,6 +652,7 @@ if is_tf_available():
         TFAutoModelForTokenClassification,
         TFAutoModelWithLMHead,
     )
+    from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
     from .modeling_tf_bert import (
         TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         TFBertEmbeddings,
...
@@ -20,6 +20,7 @@ import os
 from transformers import (
     ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    BART_PRETRAINED_MODEL_ARCHIVE_LIST,
     BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
     CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
     CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -37,6 +38,7 @@ from transformers import (
     XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
     XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
     AlbertConfig,
+    BartConfig,
     BertConfig,
     CamembertConfig,
     CTRLConfig,
@@ -49,6 +51,7 @@ from transformers import (
     RobertaConfig,
     T5Config,
     TFAlbertForPreTraining,
+    TFBartForConditionalGeneration,
     TFBertForPreTraining,
     TFBertForQuestionAnswering,
     TFBertForSequenceClassification,
@@ -87,6 +90,7 @@ if is_torch_available():
     from transformers import (
         AlbertForPreTraining,
+        BartForConditionalGeneration,
         BertForPreTraining,
         BertForQuestionAnswering,
         BertForSequenceClassification,
@@ -113,6 +117,12 @@ if is_torch_available():
 logging.set_verbosity_info()

 MODEL_CLASSES = {
+    "bart": (
+        BartConfig,
+        TFBartForConditionalGeneration,
+        BartForConditionalGeneration,
+        BART_PRETRAINED_MODEL_ARCHIVE_LIST,
+    ),
     "bert": (
         BertConfig,
         TFBertForPreTraining,
...
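The `MODEL_CLASSES` entry above registers BART with the PyTorch-to-TF conversion script. As a hedged sketch (not part of the diff), the same conversion can be done directly through `from_pretrained`; the tiny checkpoint name is taken from the test files later in this diff.

```python
from transformers import TFBartForConditionalGeneration

# from_pt=True routes through load_pytorch_weights_in_tf2_model, also touched in this diff.
tf_model = TFBartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random", from_pt=True)
tf_model.save_pretrained("./tf-bart-tiny")  # writes the converted TF weights next to the config
```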
@@ -640,6 +640,10 @@ class TFGenerationMixin:
             if temperature != 1.0:
                 next_token_logits = next_token_logits / temperature
+            if self.config.is_encoder_decoder and do_sample is False:
+                next_token_logits = self.adjust_logits_during_generation(
+                    next_token_logits, cur_len=cur_len, max_length=max_length
+                )
             # calculate log softmax score
             scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
@@ -890,6 +894,13 @@ class TFGenerationMixin:
     def _reorder_cache(past, beam_idx):
         return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)

+    def adjust_logits_during_generation(self, logits, **kwargs):
+        """
+        Implement in subclasses of :class:`~transformers.TFPreTrainedModel` for custom behavior to adjust the logits
+        in the generate method.
+        """
+        return logits
+

 def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
     # create logit penalties for already seen input_ids
...
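`adjust_logits_during_generation` is a new no-op hook that seq2seq subclasses can override; `TFBartForConditionalGeneration` uses it to force BOS/EOS tokens at the right steps. A minimal sketch of such an override follows; the base class name and config fields here are assumptions for illustration, not taken from the diff.

```python
import tensorflow as tf

class MyTFSeq2SeqModel(TFPretrainedSeq2SeqBase):  # hypothetical subclass
    def adjust_logits_during_generation(self, logits, cur_len, max_length, **kwargs):
        # Force the BOS token at the first decoding step by pushing every other logit to -inf.
        if cur_len == 1:
            vocab_size = logits.shape[-1]
            mask = tf.one_hot(self.config.bos_token_id, vocab_size,
                              on_value=0.0, off_value=-float("inf"))
            logits = logits + mask  # broadcasts over the (batch * beams) dimension
        return logits
```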
@@ -131,7 +131,7 @@ BART_INPUTS_DOCSTRING = r"""
             :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`) is a
             sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
             the decoder.
-        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+        past_key_values (:obj:`Tuple[Dict[str: tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
             Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
             If :obj:`past_key_values` are used, the user can optionally input only the last
@@ -217,12 +217,6 @@ def _make_linear_from_emb(emb):
     return lin_layer

-# Helper Functions, mostly for making masks
-def _check_shapes(shape_1, shape2):
-    if shape_1 != shape2:
-        raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
-
 def shift_tokens_right(input_ids, pad_token_id):
     """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
     prev_output_tokens = input_ids.clone()
@@ -595,7 +589,7 @@ class BartDecoder(nn.Module):
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
-        next_decoder_cache = []
+        next_decoder_cache: List[Dict] = []
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
@@ -640,7 +634,7 @@ class BartDecoder(nn.Module):
         )

-def _reorder_buffer(attn_cache, new_order):
+def _reorder_buffer(attn_cache: Dict, new_order) -> Dict:
     for k, input_buffer_k in attn_cache.items():
         if input_buffer_k is not None:
             attn_cache[k] = input_buffer_k.index_select(0, new_order)
@@ -679,17 +673,15 @@ class Attention(nn.Module):
     def forward(
         self,
         query,
-        key: Optional[Tensor],
+        key: Tensor,
         key_padding_mask: Optional[Tensor] = None,
-        layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
+        layer_state: Optional[Dict[str, Tensor]] = None,
         attn_mask: Optional[Tensor] = None,
         output_attentions=False,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """Input shape: Time(SeqLen) x Batch x Channel"""
         static_kv: bool = self.encoder_decoder_attention
         tgt_len, bsz, embed_dim = query.size()
-        assert embed_dim == self.embed_dim
-        assert list(query.size()) == [tgt_len, bsz, embed_dim]
         # get here for encoder decoder cause of static_kv
         if layer_state is not None:  # reuse k,v and encoder_padding_mask
             saved_state = layer_state.get(self.cache_key, {})
@@ -697,17 +689,16 @@ class Attention(nn.Module):
                 # previous time steps are cached - no need to recompute key and value if they are static
                 key = None
         else:
+            # this branch is hit by encoder
             saved_state = None
-            layer_state = {}

         q = self.q_proj(query) * self.scaling
-        if static_kv:
-            if key is None:
-                k = v = None
-            else:
-                k = self.k_proj(key)
-                v = self.v_proj(key)
-        else:
+        if static_kv and key is None:  # cross-attention with cache
+            k = v = None
+        elif static_kv and key is not None:  # cross-attention no prev_key found in cache
+            k = self.k_proj(key)
+            v = self.v_proj(key)
+        else:  # self-attention
             k = self.k_proj(query)
             v = self.v_proj(query)
@@ -717,18 +708,16 @@ class Attention(nn.Module):
         if v is not None:
             v = self._shape(v, -1, bsz)

-        if saved_state is not None:
-            k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
+        if saved_state:
+            k, v = self._concat_saved_state(k, v, saved_state, static_kv, bsz)

         # Update cache
-        layer_state[self.cache_key] = {
-            "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
-            "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
-            "prev_key_padding_mask": key_padding_mask if not static_kv else None,
-        }
-
-        assert k is not None
+        if isinstance(layer_state, dict):
+            cached_shape = (bsz, self.num_heads, -1, self.head_dim)  # bsz must be first for reorder_cache
+            layer_state[self.cache_key] = dict(prev_key=k.view(*cached_shape), prev_value=v.view(*cached_shape))
+
         src_len = k.size(1)
-        assert key_padding_mask is None or key_padding_mask.shape == (bsz, src_len)
         attn_weights = torch.bmm(q, k.transpose(1, 2))
         assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
@@ -736,13 +725,7 @@ class Attention(nn.Module):
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

-        # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
-        if key_padding_mask is not None and key_padding_mask.dim() == 0:
-            key_padding_mask = None
-        assert key_padding_mask is None or key_padding_mask.size()[:2] == (
-            bsz,
-            src_len,
-        )
+        # Note: deleted workaround to get around fork/join parallelism not supporting Optional types. on 2020/10/15

         if key_padding_mask is not None:  # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
@@ -750,11 +733,7 @@ class Attention(nn.Module):
             attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         attn_weights = F.softmax(attn_weights, dim=-1)
-        attn_probs = F.dropout(
-            attn_weights,
-            p=self.dropout,
-            training=self.training,
-        )
+        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)

         assert v is not None
         attn_output = torch.bmm(attn_probs, v)
@@ -767,36 +746,13 @@ class Attention(nn.Module):
             attn_weights = None
         return attn_output, attn_weights

-    def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
+    def _concat_saved_state(self, k, v, saved_state, static_kv, bsz) -> Tuple[Tensor]:
         # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
-        if "prev_key" in saved_state:
-            _prev_key = saved_state["prev_key"]
-            assert _prev_key is not None
-            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
-            if static_kv:
-                k = prev_key
-            else:
-                assert k is not None
-                k = torch.cat([prev_key, k], dim=1)
-        if "prev_value" in saved_state:
-            _prev_value = saved_state["prev_value"]
-            assert _prev_value is not None
-            prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
-            if static_kv:
-                v = prev_value
-            else:
-                assert v is not None
-                v = torch.cat([prev_value, v], dim=1)
-        assert k is not None and v is not None
-        prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
-        if prev_key_padding_mask is not None:
-            if static_kv:
-                new_key_padding_mask = prev_key_padding_mask
-            else:
-                new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
-        else:
-            new_key_padding_mask = key_padding_mask
-        return k, v, new_key_padding_mask
+        prev_K = saved_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
+        prev_V = saved_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
+        new_K = prev_K if static_kv else torch.cat([prev_K, k], dim=1)
+        new_V = prev_V if static_kv else torch.cat([prev_V, v], dim=1)
+        return new_K, new_V


 class BartClassificationHead(nn.Module):
@@ -1143,14 +1099,15 @@ class BartForConditionalGeneration(PretrainedBartModel):
     def adjust_logits_during_generation(self, logits, cur_len, max_length):
         if cur_len == 1 and self.config.force_bos_token_to_be_generated:
-            self._force_token_ids_generation(logits, self.config.bos_token_id)
+            self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
         elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
-            self._force_token_ids_generation(logits, self.config.eos_token_id)
+            self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
         return logits

-    def _force_token_ids_generation(self, scores, token_id) -> None:
+    @staticmethod
+    def _force_token_id_to_be_generated(scores, token_id) -> None:
         """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
-        scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")
+        scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")

     @staticmethod
     def _reorder_cache(past, beam_idx):
...
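The renamed `_force_token_id_to_be_generated` infers the vocabulary size from the scores tensor instead of the config, which is what allows it to become a `@staticmethod`. A tiny standalone sketch (illustrative shapes, not from the diff) of the masking it performs:

```python
import torch

scores = torch.randn(2, 5)  # (batch_size, vocab_size) dummy logits
token_id = 3
# Same masking as _force_token_id_to_be_generated: every other column becomes -inf.
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
print(scores.argmax(dim=-1))  # tensor([3, 3]): the forced token always wins
```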
@@ -52,5 +52,5 @@ class BlenderbotForConditionalGeneration(BartForConditionalGeneration):
     def adjust_logits_during_generation(self, logits, cur_len, max_length):
         logits[:, self.config.bos_token_id] = -torch.finfo(torch.float16).max  # near infinity fp16
         if cur_len == max_length - 1 and self.config.eos_token_id is not None:
-            self._force_token_ids_generation(logits, self.config.eos_token_id)
+            self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
         return logits
@@ -51,5 +51,5 @@ class MarianMTModel(BartForConditionalGeneration):
     def adjust_logits_during_generation(self, logits, cur_len, max_length):
         logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad token.
         if cur_len == max_length - 1 and self.config.eos_token_id is not None:
-            self._force_token_ids_generation(logits, self.config.eos_token_id)
+            self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
         return logits
@@ -21,6 +21,7 @@ from collections import OrderedDict
 from .configuration_auto import (
     AlbertConfig,
     AutoConfig,
+    BartConfig,
     BertConfig,
     CamembertConfig,
     CTRLConfig,
@@ -51,6 +52,7 @@ from .modeling_tf_albert import (
     TFAlbertForTokenClassification,
     TFAlbertModel,
 )
+from .modeling_tf_bart import TFBartForConditionalGeneration
 from .modeling_tf_bert import (
     TFBertForMaskedLM,
     TFBertForMultipleChoice,
@@ -206,6 +208,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
         (T5Config, TFT5ForConditionalGeneration),
         (DistilBertConfig, TFDistilBertForMaskedLM),
         (AlbertConfig, TFAlbertForMaskedLM),
+        (BartConfig, TFBartForConditionalGeneration),
         (CamembertConfig, TFCamembertForMaskedLM),
         (XLMRobertaConfig, TFXLMRobertaForMaskedLM),
         (LongformerConfig, TFLongformerForMaskedLM),
@@ -256,7 +259,9 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
     ]
 )

-TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict([(T5Config, TFT5ForConditionalGeneration)])
+TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
+    [(T5Config, TFT5ForConditionalGeneration), (BartConfig, TFBartForConditionalGeneration)]
+)

 TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
     [
...
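With `BartConfig` added to the TF auto-model mappings, the auto classes can resolve a BART checkpoint to `TFBartForConditionalGeneration`. A hedged sketch (not part of the diff; `from_pt=True` is assumed when the checkpoint only ships PyTorch weights):

```python
from transformers import TFAutoModelWithLMHead

# BartConfig now maps to TFBartForConditionalGeneration in TF_MODEL_WITH_LM_HEAD_MAPPING.
model = TFAutoModelWithLMHead.from_pretrained("sshleifer/bart-tiny-random", from_pt=True)
print(type(model).__name__)  # TFBartForConditionalGeneration
```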
This diff is collapsed.
@@ -717,7 +717,7 @@ BERT_START_DOCSTRING = r"""
     Args:
         config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
             Initializing with a config file does not load the weights associated with the model, only the configuration.
-            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
+            Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the model weights.
 """

 BERT_INPUTS_DOCSTRING = r"""
...
@@ -229,7 +229,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
     else:
         logger.warning(
             f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
-            f"If your task is similar to the task the model of the ckeckpoint was trained on, "
+            f"If your task is similar to the task the model of the checkpoint was trained on, "
             f"you can already use {tf_model.__class__.__name__} for predictions without further training."
         )
@@ -383,7 +383,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
     else:
         logger.warning(
             f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
-            f"If your task is similar to the task the model of the ckeckpoint was trained on, "
+            f"If your task is similar to the task the model of the checkpoint was trained on, "
             f"you can already use {pt_model.__class__.__name__} for predictions without further training."
         )
...
@@ -20,6 +20,7 @@ import copy
 import itertools
 import math
 import warnings
+from typing import Tuple

 import tensorflow as tf
@@ -594,7 +595,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         output_hidden_states=None,
         training=False,
         **kwargs,
-    ):
+    ) -> Tuple:
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
@@ -699,7 +700,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and -10000.0 for masked positions.
+            # positions we want to attend and -1e9 for masked positions.
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
@@ -721,7 +722,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             if num_dims_encoder_attention_mask == 2:
                 encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]

-            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposistion
+            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
             # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
             # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
             #                                                 tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
@@ -1417,7 +1418,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
             "use_cache": use_cache,
         }

-    def _reorder_cache(self, past, beam_idx):
+    def _reorder_cache(self, past, beam_idx) -> Tuple:
         # if decoder past is not included in output
         # speedy decoding is disabled and no need to reorder
...
@@ -136,8 +136,7 @@ class TFCausalLanguageModelingLoss:
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        # make sure only labels that are not equal to -100
-        # are taken into account as loss
+        # make sure only labels that are not equal to -100 affect the loss
         active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
         reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
         labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
...
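A small standalone illustration of the `-100` label masking that the reworded comment describes; the label values here are made up for the example.

```python
import tensorflow as tf

labels = tf.constant([5, -100, 7])        # -100 marks positions that should not contribute to the loss
active_loss = tf.not_equal(labels, -100)  # [True, False, True]
kept_labels = tf.boolean_mask(labels, active_loss)
print(kept_labels.numpy())                # [5 7] -> only these enter the cross-entropy
```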
@@ -1945,11 +1945,6 @@ class SummarizationPipeline(Pipeline):
         assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
         assert len(documents) > 0, "Please provide a document to summarize"

-        if self.framework == "tf" and "BartForConditionalGeneration" in self.model.__class__.__name__:
-            raise NotImplementedError(
-                "Tensorflow is not yet supported for Bart. Please consider using T5, e.g. `t5-base`"
-            )
-
         prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
         if isinstance(documents[0], list):
...
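Removing this guard means the summarization pipeline can run BART under TensorFlow. A hedged sketch using the tiny checkpoint from this PR's pipeline tests, assuming TF weights are available for it (or that the pipeline's PT-to-TF loading path applies):

```python
from transformers import pipeline

summarizer = pipeline("summarization", model="sshleifer/bart-tiny-random", framework="tf")
print(summarizer("A long document to summarize ...", max_length=20, min_length=2))
```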
@@ -212,6 +212,24 @@ class TFAutoModelWithLMHead:
         requires_tf(self)


+class TFBartForConditionalGeneration:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
+
+
+class TFBartModel:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
+
+
 TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
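These dummy classes keep `from transformers import TFBartModel` importable when TensorFlow is missing; instantiation then fails with a clear error instead of a bare ImportError at import time. A sketch of the behaviour, assuming `requires_tf` raises an `ImportError` as the other dummy objects do:

```python
from transformers import TFBartModel  # import succeeds even without TensorFlow installed

try:
    model = TFBartModel()  # requires_tf raises here when TensorFlow is unavailable
except ImportError as err:
    print(err)
```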
This diff is collapsed.
This diff is collapsed.
@@ -302,7 +302,7 @@ class TFModelTesterMixin:
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         for model_class in self.all_model_classes:
-            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beggining
+            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
             pt_model_class = getattr(transformers, pt_model_class_name)

             config.output_hidden_states = True
@@ -472,10 +472,9 @@ class TFModelTesterMixin:
             # Prepare our model
             model = model_class(config)
+            model(self._prepare_for_class(inputs_dict, model_class))  # Model must be called before saving.
             # Let's load it from the disk to be sure we can use pretrained weights
             with tempfile.TemporaryDirectory() as tmpdirname:
-                outputs = model(self._prepare_for_class(inputs_dict, model_class))  # build the model
                 model.save_pretrained(tmpdirname)
                 model = model_class.from_pretrained(tmpdirname)
@@ -494,7 +493,9 @@ class TFModelTesterMixin:
         for model_class in self.all_model_classes:
             model = model_class(config)

-            outputs_dict = model(self._prepare_for_class(inputs_dict, model_class))
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+
+            outputs_dict = model(inputs)

             inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
             input_ids = inputs_keywords.pop("input_ids", None)
@@ -507,28 +508,18 @@ class TFModelTesterMixin:
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        decoder_seq_length = (
-            self.model_tester.decoder_seq_length
-            if hasattr(self.model_tester, "decoder_seq_length")
-            else self.model_tester.seq_length
-        )
-        encoder_seq_length = (
-            self.model_tester.encoder_seq_length
-            if hasattr(self.model_tester, "encoder_seq_length")
-            else self.model_tester.seq_length
-        )
-        decoder_key_length = (
-            self.model_tester.key_length if hasattr(self.model_tester, "key_length") else decoder_seq_length
-        )
-        encoder_key_length = (
-            self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
-        )
+        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", self.model_tester.seq_length)
+        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
+        decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
+        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
+            inputs_dict["use_cache"] = False
             config.output_hidden_states = False
             model = model_class(config)
-            outputs = model(self._prepare_for_class(inputs_dict, model_class))
+            model_inputs = self._prepare_for_class(inputs_dict, model_class)
+            outputs = model(model_inputs)
             attentions = [t.numpy() for t in outputs[-1]]
             self.assertEqual(model.config.output_hidden_states, False)
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
...
@@ -279,9 +279,8 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
     @slow
     def test_model_from_pretrained(self):
-        for model_name in ["t5-small"]:
-            model = TFT5Model.from_pretrained(model_name)
-            self.assertIsNotNone(model)
+        model = TFT5Model.from_pretrained("t5-small")
+        self.assertIsNotNone(model)


 @require_tf
...
@@ -21,7 +21,7 @@ FILL_MASK_FINETUNED_MODELS = ["sshleifer/tiny-distilroberta-base"]
 LARGE_FILL_MASK_FINETUNED_MODELS = ["distilroberta-base"]  # @slow

 SUMMARIZATION_FINETUNED_MODELS = ["sshleifer/bart-tiny-random", "patrickvonplaten/t5-tiny-random"]
-TF_SUMMARIZATION_FINETUNED_MODELS = ["patrickvonplaten/t5-tiny-random"]
+TF_SUMMARIZATION_FINETUNED_MODELS = ["sshleifer/bart-tiny-random", "patrickvonplaten/t5-tiny-random"]

 TRANSLATION_FINETUNED_MODELS = [
     ("patrickvonplaten/t5-tiny-random", "translation_en_to_de"),
...