Unverified Commit abc573f5 authored by Patrick von Platen, committed by GitHub

[TF Bart] Refactor TFBart (#9029)

* reorder file

* delete unnecessary function

* make style

* save intermediate

* fix attention masks

* correct tf bart past key values

* solve merge conflict bug

* correct tensor dims

* save intermediate tf

* change attn layer

* fix typo re-order past

* inputs_embeds

* make fix copies

* finish tests

* fix graph mode

* apply Lysandre's suggestions
parent 389aba34
@@ -717,7 +717,7 @@ if is_tf_available():
        TFAutoModelForTokenClassification,
        TFAutoModelWithLMHead,
    )
-   from .models.bart import TFBartForConditionalGeneration, TFBartModel
+   from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
    from .models.bert import (
        TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFBertEmbeddings,
...
@@ -36,4 +36,4 @@ if is_torch_available():
    )

if is_tf_available():
-   from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
+   from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
@@ -215,10 +215,10 @@ class BartAttention(nn.Module):
    def forward(
        self,
-       hidden_states,
+       hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-       attn_mask: Optional[torch.Tensor] = None,
+       attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
@@ -274,14 +274,14 @@ class BartAttention(nn.Module):
            src_len,
        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"

-       if attn_mask is not None:
-           assert attn_mask.size() == (
+       if attention_mask is not None:
+           assert attention_mask.size() == (
                bsz,
                1,
                tgt_len,
                src_len,
-           ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attn_mask.size()}"
-           attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
+           ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+           attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = F.softmax(attn_weights, dim=-1)
@@ -335,23 +335,19 @@ class BartEncoderLayer(nn.Module):
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = BartLayerNorm(self.embed_dim)

-   def forward(
-       self, hidden_states: torch.Tensor, encoder_padding_mask: torch.Tensor, output_attentions: bool = False
-   ):
+   def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
        """
        Args:
            hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
-           encoder_padding_mask (:obj:`torch.FloatTensor`): attention mask of size
+           attention_mask (:obj:`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+           output_attentions (:obj:`bool`): Whether the base model outputs attentions. This requires the attentions tensor to be reshaped in this function.
-           Returns:
-               encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = hidden_states
        if self.normalize_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.self_attn(
-           hidden_states=hidden_states, attn_mask=encoder_padding_mask, output_attentions=output_attentions
+           hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
@@ -405,24 +401,35 @@ class BartDecoderLayer(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
-       encoder_hidden_states: torch.Tensor,
-       encoder_attn_mask: Optional[torch.Tensor] = None,
+       attention_mask: Optional[torch.Tensor] = None,
+       encoder_hidden_states: Optional[torch.Tensor] = None,
+       encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-       attn_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[torch.Tensor] = False,
    ):
"""
Args:
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`): Whether the base model outputs attentions. This requires the attentions tensor to be reshaped in this function.
"""
        residual = hidden_states
        if self.normalize_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
-           attn_mask=attn_mask,
+           attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -443,7 +450,7 @@ class BartDecoderLayer(nn.Module):
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
-               attn_mask=encoder_attn_mask,
+               attention_mask=encoder_attention_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
@@ -905,9 +912,9 @@ class BartDecoder(BartPretrainedModel):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-       attn_mask = None
+       combined_attention_mask = None
        if input_shape[-1] > 1:
-           attn_mask = _make_causal_mask(
+           combined_attention_mask = _make_causal_mask(
                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
            ).to(self.device)
@@ -928,9 +935,9 @@ class BartDecoder(BartPretrainedModel):
            # never mask leading token, even if it is pad
            attention_mask[:, 0] = attention_mask[:, 1]

-       if attention_mask is not None and attn_mask is not None:
+       if attention_mask is not None and combined_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-           attn_mask = attn_mask + _expand_mask(
+           combined_attention_mask = combined_attention_mask + _expand_mask(
                attention_mask, inputs_embeds.dtype, past_key_values_length=past_key_values_length
            )
@@ -968,9 +975,9 @@ class BartDecoder(BartPretrainedModel):
            hidden_states, layer_self_attn, present_key_value, layer_cross_attn = decoder_layer(
                hidden_states,
-               encoder_hidden_states,
-               encoder_attn_mask=encoder_attention_mask,
-               attn_mask=attn_mask,
+               attention_mask=combined_attention_mask,
+               encoder_hidden_states=encoder_hidden_states,
+               encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
            )
...
@@ -16,6 +16,7 @@
import math
import random
+import warnings
from typing import Dict, Optional, Tuple, Union

import numpy as np
@@ -49,11 +50,457 @@ from ...utils import logging
from .configuration_bart import BartConfig

+logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "BartConfig"
_TOKENIZER_FOR_DOC = "BartTokenizer"

-BART_START_DOCSTRING = r"""
+LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, eos_token_id: int):
shifted_input_ids = tf.cast(input_ids, tf.int32)
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), eos_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
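
For context, a minimal usage sketch (illustration only, not part of the diff), assuming the default BART special tokens pad_token_id=1 and eos_token_id=2 and a toy label tensor:

# labels                  [[42,   7, 2, -100, -100]]
# roll right by one    -> [[-100, 42, 7,  2,  -100]]
# eos at position 0    -> [[  2,  42, 7,  2,  -100]]
# -100 -> pad_token_id -> [[  2,  42, 7,  2,    1]]
labels = tf.constant([[42, 7, 2, -100, -100]])
decoder_input_ids = shift_tokens_right(labels, pad_token_id=1, eos_token_id=2)  # [[2, 42, 7, 2, 1]]
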
def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
mask = tf.cast(mask, tf.float32)
if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
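
Illustration only: the returned additive mask is zero on and below the diagonal and LARGE_NEGATIVE above it, and positions covered by `past_key_values_length` are prepended as fully visible (zero) columns. For a toy shape:

# causal_mask[0, 0] == [[0., 0., 0., -1e8, -1e8],
#                       [0., 0., 0.,   0., -1e8],
#                       [0., 0., 0.,   0.,   0.]]
causal_mask = _make_causal_mask((1, 3), past_key_values_length=2)  # shape (1, 1, 3, 5)
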
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = shape_list(mask)
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
if past_key_values_length > 0:
# concat fully attended attention_mask to the beginning if `past_key_values` are used
expanded_mask = tf.concat(
[
tf.ones((bsz, 1, tgt_len, past_key_values_length), dtype=tf.float32),
expanded_mask,
],
axis=-1,
)
return (1.0 - expanded_mask) * LARGE_NEGATIVE
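
Illustration only: a `[bsz, seq_len]` padding mask of ones and zeros becomes an additive mask with 0. for visible tokens and LARGE_NEGATIVE for padded ones, broadcast over the target length:

padding_mask = tf.constant([[1.0, 1.0, 0.0]])  # last token is padding
additive_mask = _expand_mask(padding_mask)     # shape (1, 1, 3, 3)
# additive_mask[0, 0, 0] == [0., 0., -1e8]
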
class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
"""
This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
the forward function.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset, **kwargs):
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models dont have this hack
self.offset = offset
assert padding_idx is not None, "padding_idx cannot be None"
num_embeddings += offset
super().__init__(num_embeddings, embedding_dim, **kwargs)
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
positions = tf.range(
past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
)
return super().call(positions + self.offset) # super object is not callable for some reason
class TFBartSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
if embedding_dim % 2 != 0:
raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
super().__init__(
num_positions,
embedding_dim,
**kwargs,
)
def build(self, input_shape: tf.TensorShape):
"""
Build shared token embedding layer Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
super().build(input_shape) # Instantiates self.weight so it can be loaded
weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim)
self.set_weights([weight]) # overwrite self.weight to correct value
@staticmethod
def _init_weight(n_pos: int, dim: int):
"""
Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
the 2nd half of the vector. [dim // 2:]
"""
position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
# index 0 is all zero
position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
# convert to tensor
table = tf.convert_to_tensor(position_enc, dtype=tf.float32)
tf.stop_gradient(table)
return table
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
positions = tf.range(
past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
)
return super().call(positions)
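
Side note (illustration only, toy sizes): the resulting table keeps all sine features in the first half of each embedding vector and all cosine features in the second half, e.g. for embedding_dim=4 a row is [sin(pos), sin(pos/100), cos(pos), cos(pos/100)]:

emb = TFBartSinusoidalPositionalEmbedding(num_positions=8, embedding_dim=4)
pos_embeds = emb(tf.constant([2, 5]))  # input_shape (bsz=2, seq_len=5) -> positions 0..4, shape (5, 4)
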
class TFBartAttention(tf.keras.layers.Layer):
"""Multi-headed attention from "Attention Is All You Need"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = tf.keras.layers.Dropout(dropout)
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder
self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
self.q_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
self.v_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
self.out_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
def call(
self,
hidden_states: tf.Tensor,
key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = shape_list(hidden_states)
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = tf.concat([past_key_value[0], key_states], axis=2)
value_states = tf.concat([past_key_value[1], value_states], axis=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
key_states = tf.reshape(key_states, proj_shape)
value_states = tf.reshape(value_states, proj_shape)
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
)
if attention_mask is not None:
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
)
attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
attn_output = self.out_proj(attn_output)
attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
return attn_output, attn_weights, past_key_value
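
Illustration only (toy sizes): how the cache returned above is consumed on the next decoding step. On the first call the projected key/value states come back as `past_key_value`; on the second call only the new token is passed and the cached states are concatenated along the time axis:

attn = TFBartAttention(embed_dim=16, num_heads=2, is_decoder=True)
prefix = tf.random.uniform((1, 4, 16))     # (bsz, seq_len, embed_dim)
_, _, past = attn(prefix)                  # past = (key_states, value_states)
new_token = tf.random.uniform((1, 1, 16))
out, weights, past = attn(new_token, past_key_value=past)
# past[0].shape == (1, 2, 5, 8) -> (bsz, num_heads, prefix_len + 1, head_dim)
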
class TFBartEncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, **kwargs):
super().__init__(**kwargs)
self.embed_dim = config.d_model
self.self_attn = TFBartAttention(
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
"""
Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
"""
residual = hidden_states
if self.normalize_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
if not self.normalize_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
if self.normalize_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout(hidden_states, training=training)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
if not self.normalize_before:
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states, self_attn_weights
class TFBartDecoderLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, **kwargs):
super().__init__(**kwargs)
self.embed_dim = config.d_model
self.self_attn = TFBartAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
name="self_attn",
is_decoder=True,
)
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.encoder_attn = TFBartAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
name="encoder_attn",
is_decoder=True,
)
self.encoder_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(
self,
hidden_states,
attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
"""
Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
if self.normalize_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
if not self.normalize_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
cross_attn_present_key_value = None
if encoder_hidden_states is not None:
residual = hidden_states
if self.normalize_before:
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
if not self.normalize_before:
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# Fully Connected
residual = hidden_states
if self.normalize_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout(hidden_states, training=training)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
if not self.normalize_before:
hidden_states = self.final_layer_norm(hidden_states)
return (
hidden_states,
self_attn_weights,
present_key_value,
)
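
Illustration only: layout of the cache tuple this layer returns and later slices apart. When `encoder_hidden_states` is passed, `present_key_value` holds four tensors:

# present_key_value[0:2] -> self-attention  (key_states, value_states); grows by one position per generated token
# present_key_value[2:4] -> cross-attention (key_states, value_states); computed once from the encoder output
# on the next step the layer recovers them via:
#   self_attn_past_key_value  = past_key_value[:2]
#   cross_attn_past_key_value = past_key_value[-2:]
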
class TFBartPretrainedModel(TFPreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
@property
def dummy_inputs(self):
pad_token = 1
input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32)
decoder_input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32)
dummy_inputs = {
"decoder_input_ids": decoder_input_ids,
"attention_mask": tf.math.not_equal(input_ids, pad_token),
"input_ids": input_ids,
}
return dummy_inputs
class TFPretrainedBartModel(TFBartPretrainedModel):
def __init_subclass__(self):
warnings.warn(
"The class `TFPretrainedBartModel` has been deprecated, please use `TFBartPretrainedModel` instead.",
FutureWarning,
)
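
Illustration only: the shim warns as soon as the old name is subclassed, while existing code keeps working:

class MyOldStyleBart(TFPretrainedBartModel):  # hypothetical subclass; emits the FutureWarning above
    pass
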
BART_START_DOCSTRING = r"""
    This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
    generic methods the library implements for all its model (such as downloading or saving, resizing the input
    embeddings, pruning heads etc.)
@@ -75,7 +522,7 @@ BART_START_DOCSTRING = r"""
    If you choose this second option, there are three possibilities you can use to gather all the input Tensors in
    the first positional argument :

-   - a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(inputs_ids)`
+   - a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
      :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
@@ -88,7 +535,6 @@ BART_START_DOCSTRING = r"""
    model weights.
"""

BART_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`tf.Tensor` of shape :obj:`({0})`):
@@ -114,7 +560,7 @@ BART_INPUTS_DOCSTRING = r"""
        encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
            hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
            of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
-       past_key_values (:obj:`Tuple[Dict[str: tf.Tensor]]` of length :obj:`config.n_layers`)
+       past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
            contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
@@ -130,170 +576,26 @@ BART_INPUTS_DOCSTRING = r"""
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.TFModelOutput` instead of a plain tuple.
        training (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
LARGE_NEGATIVE = -1e8
logger = logging.get_logger(__name__)
def create_position_ids_from_input_ids(input_ids, padding_idx):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
"""
mask = input_ids.ne(padding_idx).int()
incremental_indices = tf.cumsum(mask, axis=1).type_as(mask) * mask
return incremental_indices.long() + padding_idx
def causal_attention_mask(nd, ns, dtype):
"""
1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]), -1,
ns-nd), but doesn't produce garbage on TPUs.
"""
i = tf.range(nd)[:, None]
j = tf.range(ns)
m = i < j - ns + nd
return tf.cast(m, dtype) * LARGE_NEGATIVE
def invert_mask(attention_mask: tf.Tensor):
"""Turns 1->0, 0->1, False->True, True-> False"""
tf.debugging.assert_rank(attention_mask, 2)
attention_mask = tf.cast(attention_mask, tf.bool)
ret = tf.math.logical_not(attention_mask) # dtype is tf.bool
return ret
class TFPretrainedBartModel(TFPreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
@property
def dummy_inputs(self):
pad_token = 1
input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32)
decoder_input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32)
dummy_inputs = {
"decoder_input_ids": decoder_input_ids,
"attention_mask": tf.math.not_equal(input_ids, pad_token),
"input_ids": input_ids,
}
return dummy_inputs
def _shift_right(self, input_ids):
# Should maybe be decoder_start_token_id. Change for torch and TF in one PR
position_0_id = self.config.eos_token_id
pad_token_id = self.config.pad_token_id
shifted_input_ids = tf.cast(input_ids, tf.int32)
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), position_0_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
# Helper Functions, mostly for making masks
def make_padding_mask(input_ids, padding_idx=1):
"""True for pad tokens"""
padding_mask = tf.math.equal(input_ids, padding_idx) # bool tensor
return padding_mask
# Helper Modules
PAST_KV_DEPRECATION_WARNING = (
"The `past_key_value_states` argument is deprecated and will be removed in a future "
"version, use `past_key_values` instead."
)
class TFEncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, **kwargs):
super().__init__(**kwargs)
self.embed_dim = config.d_model
self.self_attn = TFAttention(
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, x, encoder_padding_mask, training=False):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
for t_tgt, t_src is excluded (or masked out), =0 means it is
included in attention
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, self_attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask)
tf.debugging.assert_equal(
shape_list(x),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}",
)
x = self.dropout(x, training=training)
x = residual + x
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout(x, training=training)
x = self.fc2(x)
x = self.dropout(x, training=training)
x = residual + x
if not self.normalize_before:
x = self.final_layer_norm(x)
return x, self_attn_weights
@keras_serializable
class TFBartEncoder(tf.keras.layers.Layer):
-   # config_class = BartConfig
+   config_class = BartConfig
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
-   :class:`TFEncoderLayer`.
+   :class:`TFBartEncoderLayer`.

    Args:
        config: BartConfig
    """

-   def __init__(self, config: BartConfig, embed_tokens: TFSharedEmbeddings, **kwargs):
+   def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
        super().__init__(**kwargs)
+       self.config = config
        self.dropout = tf.keras.layers.Dropout(config.dropout)
        self.layerdrop = config.encoder_layerdrop
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
@@ -302,20 +604,20 @@ class TFBartEncoder(tf.keras.layers.Layer):
        self.embed_tokens = embed_tokens
        if config.static_position_embeddings:
-           self.embed_positions = TFSinusoidalPositionalEmbedding(
+           self.embed_positions = TFBartSinusoidalPositionalEmbedding(
                config.max_position_embeddings,
                config.d_model,
                name="embed_positions",
            )
        else:
-           self.embed_positions = TFLearnedPositionalEmbedding(
+           self.embed_positions = TFBartLearnedPositionalEmbedding(
                config.max_position_embeddings,
                config.d_model,
                self.padding_idx,
                config.extra_pos_embeddings,
                name="embed_positions",
            )
-       self.layers = [TFEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
+       self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
        self.layernorm_embedding = (
            tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
            if config.normalize_embedding
@@ -330,203 +632,148 @@ class TFBartEncoder(tf.keras.layers.Layer):
    def call(
        self,
        input_ids=None,
+       inputs_embeds=None,
        attention_mask=None,
-       output_attentions=False,
-       output_hidden_states=False,
+       output_attentions=None,
+       output_hidden_states=None,
        return_dict=None,
        training=False,
+       **kwargs,
    ):
        """
        Args:
-           input_ids (Tensor): tokens in the source language of shape
-               `(batch, src_len)`
-           attention_mask (Tensor): indicating which indices are padding tokens
+           input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
+               Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+               provide it.
Indices can be obtained using :class:`~transformers.BartTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
into associated vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
Returns: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
namedtuple: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
- **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` if inputs["inputs_embeds"] is None:
inputs_embeds = self.embed_tokens(inputs["input_ids"])
else:
inputs_embeds = inputs["inputs_embeds"]
inputs_embeds = inputs_embeds * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.dropout(hidden_states, training=inputs["training"])
- **encoder_states** (List[tf.Tensor]): all intermediate hidden states of shape `(src_len, batch,
embed_dim)`. Only populated if *output_hidden_states* is True.
- **all_attentions** (List[tf.Tensor]): Attention weights for each layer.
During training might not be of length n_layers because of layer dropout.
"""
# check attention mask and invert # check attention mask and invert
if attention_mask is not None: if inputs["attention_mask"] is not None:
assert ( # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask._rank() == 2 attention_mask = _expand_mask(inputs["attention_mask"])
), f"expected attention_mask._rank() to be a 2D tensor got {attention_mask._rank()}" else:
attention_mask = tf.cast(attention_mask, dtype=tf.float32) attention_mask = None
attention_mask = (1.0 - attention_mask) * LARGE_NEGATIVE
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale encoder_states = () if inputs["output_hidden_states"] else None
embed_pos = self.embed_positions(input_ids) all_attentions = () if inputs["output_attentions"] else None
x = inputs_embeds + embed_pos
x = self.layernorm_embedding(x)
x = self.dropout(x, training=training)
# B x T x C -> T x B x C
x = tf.transpose(x, perm=[1, 0, 2])
encoder_states = [] if output_hidden_states else None
all_attentions = () if output_attentions else None
# encoder layers # encoder layers
for encoder_layer in self.layers: for encoder_layer in self.layers:
if output_hidden_states: if inputs["output_hidden_states"]:
encoder_states.append(x) encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1) dropout_probability = random.uniform(0, 1)
if training and (dropout_probability < self.layerdrop): # skip the layer if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
attn = None continue
else:
x, attn = encoder_layer(x, attention_mask) hidden_states, attn = encoder_layer(hidden_states, attention_mask)
if output_attentions: if inputs["output_attentions"]:
all_attentions += (attn,) all_attentions += (attn,)
if self.layer_norm: if self.layer_norm:
x = self.layer_norm(x) hidden_states = self.layer_norm(hidden_states)
if output_hidden_states: if inputs["output_hidden_states"]:
encoder_states.append(x) encoder_states = encoder_states + (hidden_states,)
encoder_states = [tf.transpose(hidden_state, perm=(1, 0, 2)) for hidden_state in encoder_states]
x = tf.transpose(x, perm=(1, 0, 2))
if not return_dict:
return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
return TFBaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
class TFDecoderLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, **kwargs):
super().__init__(**kwargs)
self.embed_dim = config.d_model
self.self_attn = TFAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
name="self_attn",
)
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.encoder_attn = TFAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
encoder_decoder_attention=True,
name="encoder_attn",
)
self.encoder_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(
self,
x,
encoder_hidden_states: tf.Tensor,
encoder_attn_mask=None,
layer_state=None,
causal_mask=None,
decoder_padding_mask=None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Dict[str, tf.Tensor]]:
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attn_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``.
need_attn_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
Tuple containing, encoded output of shape `(seq_len, batch, embed_dim)`, self_attn_weights, layer_state if not inputs["return_dict"]:
""" return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
residual = x # Make a copy of the input tensor to add later. return TFBaseModelOutput(
if layer_state is None: last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
layer_state = {}
if self.normalize_before:
x = self.self_attn_layer_norm(x)
# next line mutates layer state and we need a copy of it
x, self_attn_weights = self.self_attn(
query=x,
key=x,
layer_state=layer_state,
attn_mask=causal_mask,
key_padding_mask=decoder_padding_mask,
)
x = self.dropout(x, training=training)
x = residual + x
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
# Cross-Attention Block
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
x, _ = self.encoder_attn(
query=x,
key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
) )
x = self.dropout(x, training=training)
x = residual + x
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
# Fully Connected
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout(x, training=training)
x = self.fc2(x)
x = self.dropout(x, training=training)
x = residual + x
if not self.normalize_before:
x = self.final_layer_norm(x)
return (
x,
self_attn_weights,
layer_state,
) # just self_attn weights for now, following t5, layer_state = cache for decoding
@keras_serializable
class TFBartDecoder(tf.keras.layers.Layer):
+   config_class = BartConfig
    """
-   Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`TFDecoderLayer`
+   Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`TFBartDecoderLayer`

    Args:
        config: BartConfig
        embed_tokens: output embedding
    """

-   def __init__(self, config: BartConfig, embed_tokens, **kwargs):
+   def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
        super().__init__(**kwargs)
-       self.layerdrop = config.decoder_layerdrop
+       self.config = config
        self.padding_idx = config.pad_token_id
-       self.max_target_positions = config.max_position_embeddings
        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+       self.layerdrop = config.decoder_layerdrop
        if config.static_position_embeddings:
-           self.embed_positions = TFSinusoidalPositionalEmbedding(
+           self.embed_positions = TFBartSinusoidalPositionalEmbedding(
                config.max_position_embeddings,
                config.d_model,
                name="embed_positions",
            )
        else:
-           self.embed_positions = TFLearnedPositionalEmbedding(
+           self.embed_positions = TFBartLearnedPositionalEmbedding(
                config.max_position_embeddings,
                config.d_model,
                self.padding_idx,
                config.extra_pos_embeddings,
                name="embed_positions",
            )
-       self.layers = [TFDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
+       self.layers = [TFBartDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
        self.layernorm_embedding = (
            tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
            if config.normalize_embedding
@@ -543,322 +790,197 @@ class TFBartDecoder(tf.keras.layers.Layer):
    def call(
        self,
-       input_ids,
-       encoder_hidden_states,
-       encoder_padding_mask,
-       decoder_padding_mask,
-       decoder_causal_mask,
-       decoder_cached_states=None,
-       use_cache=False,
-       output_attentions=False,
-       output_hidden_states=False,
+       input_ids=None,
+       inputs_embeds=None,
+       attention_mask=None,
+       encoder_hidden_states=None,
+       encoder_attention_mask=None,
+       past_key_values=None,
+       use_cache=None,
+       output_attentions=None,
+       output_hidden_states=None,
        return_dict=None,
        training=False,
+       **kwargs,
    ):
-       # check attention mask and invert
-       if encoder_padding_mask is not None:
-           encoder_padding_mask = invert_mask(encoder_padding_mask)
+       r"""
+       Args:
+           input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using :class:`~transformers.BartTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`):
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last
:obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of
shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size,
sequence_length)`.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
into associated vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
past_key_values_length = (
inputs["past_key_values"][0][0].shape[2] if inputs["past_key_values"] is not None else 0
)
        # embed positions
-       positions = self.embed_positions(input_ids, use_cache=(use_cache and decoder_cached_states is not None))
-       if use_cache and decoder_cached_states is not None:
-           input_ids = input_ids[:, -1:]
-           positions = positions[:, -1:]
-       x = self.embed_tokens(input_ids) * self.embed_scale
-       if self.do_blenderbot_90_layernorm:
-           x = self.layernorm_embedding(x) + positions
-       else:
-           x = self.layernorm_embedding(x + positions)
-       x = self.dropout(x, training=training)
-       # Convert to Bart output format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
-       x = tf.transpose(x, perm=(1, 0, 2))
-       assert len(shape_list(encoder_hidden_states)) == 3, "encoder_hidden_states must be a 3D tensor"
-       encoder_hidden_states = tf.transpose(encoder_hidden_states, perm=(1, 0, 2))
+       positions = self.embed_positions(input_shape, past_key_values_length)
+       if inputs_embeds is None:
+           inputs_embeds = self.embed_tokens(inputs["input_ids"])
+       else:
+           inputs_embeds = inputs["inputs_embeds"]
+       hidden_states = inputs_embeds * self.embed_scale
+       # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+       combined_attention_mask = None
+       if input_shape[-1] > 1:
+           combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+       if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
+           attention_mask = tf.cast(
+               tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
+           )
+       else:
+           attention_mask = tf.ones(input_shape, dtype=tf.int32)
+       if attention_mask is not None and combined_attention_mask is not None:
+           # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+           combined_attention_mask = combined_attention_mask + _expand_mask(
+               attention_mask, past_key_values_length=past_key_values_length
+           )
+       encoder_hidden_states = inputs["encoder_hidden_states"]
+       if encoder_hidden_states is not None and inputs["encoder_attention_mask"] is not None:
+           # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+           encoder_attention_mask = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
+       if self.do_blenderbot_90_layernorm:
+           hidden_states = self.layernorm_embedding(hidden_states) + positions
+       else:
+           hidden_states = self.layernorm_embedding(hidden_states + positions)
+       hidden_states = self.dropout(hidden_states, training=inputs["training"])
        # decoder layers
        all_hidden_states = ()
        all_self_attns = ()
-       next_decoder_cache = []
+       present_key_values = ()
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-           if output_hidden_states:
-               all_hidden_states += (x,)
+           if inputs["output_hidden_states"]:
+               all_hidden_states += (hidden_states,)
            dropout_probability = random.uniform(0, 1)
-           if training and (dropout_probability < self.layerdrop):
+           if inputs["training"] and (dropout_probability < self.layerdrop):
                continue
-           layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None
-           x, layer_self_attn, layer_past = decoder_layer(
-               x,
-               encoder_hidden_states,
-               encoder_attn_mask=encoder_padding_mask,
-               decoder_padding_mask=decoder_padding_mask,
-               layer_state=layer_state,
-               causal_mask=decoder_causal_mask,
-           )
+           past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
+           hidden_states, layer_self_attn, present_key_value = decoder_layer(
+               hidden_states,
+               attention_mask=combined_attention_mask,
+               encoder_hidden_states=encoder_hidden_states,
+               encoder_attention_mask=encoder_attention_mask,
+               past_key_value=past_key_value,
+           )
-           if use_cache:
-               next_decoder_cache.append(layer_past.copy())
+           if inputs["use_cache"]:
+               present_key_values += (present_key_value,)
-           if output_attentions:
+           if inputs["output_attentions"]:
                all_self_attns += (layer_self_attn,)
        if self.layer_norm is not None:  # same as if config.add_final_layer_norm
-           x = self.layer_norm(x)
+           hidden_states = self.layer_norm(hidden_states)
-       # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
-       if output_hidden_states:
-           all_hidden_states += (x,)
-           # T x B x C -> B x T x C
-           all_hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in all_hidden_states)
+       if inputs["output_hidden_states"]:
+           all_hidden_states += (hidden_states,)
        else:
            all_hidden_states = None
-       all_self_attns = list(all_self_attns) if output_attentions else None
-       x = tf.transpose(x, perm=(1, 0, 2))
-       encoder_hidden_states = tf.transpose(encoder_hidden_states, perm=(1, 0, 2))  # could maybe be avoided.
-       next_cache = (encoder_hidden_states, next_decoder_cache) if use_cache else None
+       all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
+       present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
-       if not return_dict:
-           return x, next_cache, all_hidden_states, all_self_attns
+       if not inputs["return_dict"]:
+           return hidden_states, present_key_values, all_hidden_states, all_self_attns
        else:
            return TFBaseModelOutputWithPast(
-               last_hidden_state=x,
-               past_key_values=next_cache,
+               last_hidden_state=hidden_states,
+               past_key_values=present_key_values,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
            )
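# The combined mask built above follows the usual additive convention: 0.0 where attention
# is allowed, a large negative value where it is blocked, so causal and padding masks can
# simply be summed. A minimal standalone sketch of that convention (the real
# _make_causal_mask / _expand_mask helpers live elsewhere in this file; the function names
# and the -1e9 constant here are illustrative assumptions):
def causal_mask_sketch(tgt_len, past_key_values_length=0):
    # token i may attend to every cached position and to positions <= i of the new block
    i = tf.range(tgt_len)[:, None]
    j = tf.range(tgt_len + past_key_values_length)[None, :]
    allowed = j <= i + past_key_values_length
    return tf.where(allowed, 0.0, -1e9)[None, None, :, :]  # [1, 1, tgt_len, src_len]

def padding_mask_sketch(attention_mask, tgt_len):
    # [bsz, src_len] with 1 = keep, 0 = pad  ->  additive [bsz, 1, tgt_len, src_len]
    mask = (1.0 - tf.cast(attention_mask, tf.float32)) * -1e9
    return tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

combined = causal_mask_sketch(4) + padding_mask_sketch(tf.constant([[1, 1, 1, 0]]), tgt_len=4)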
def _reorder_buffer(attn_cache, new_order):
for k, input_buffer_k in attn_cache.items():
if input_buffer_k is not None:
attn_cache[k] = tf.gather(input_buffer_k, new_order, axis=0)
return attn_cache
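# Hypothetical beam-search use of _reorder_buffer (toy shapes): each cached tensor is laid
# out as (bsz, num_heads, seq_len, head_dim), and the batch axis is permuted so every
# surviving beam inherits the cache of its parent hypothesis.
cache = {"prev_key": tf.zeros((4, 2, 5, 8)), "prev_value": tf.zeros((4, 2, 5, 8))}
new_order = tf.constant([2, 0, 0, 3])  # e.g. new beams 1 and 2 both continue old beam 0
cache = _reorder_buffer(cache, new_order)  # gathers each cached tensor along axis 0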
class TFAttention(tf.keras.layers.Layer):
"""Multi-headed attention from "Attention Is All You Need"""
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
encoder_decoder_attention=False, # otherwise self_attention
**kwargs,
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = tf.keras.layers.Dropout(dropout)
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.encoder_decoder_attention = encoder_decoder_attention
self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
self.q_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
self.v_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
self.out_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
def _shape(self, tensor: tf.Tensor, dim_0, bsz) -> tf.Tensor:
reshaped_T_B_D = tf.reshape(tensor, (dim_0, bsz * self.num_heads, self.head_dim))
return tf.transpose(reshaped_T_B_D, perm=(1, 0, 2))
def call(
self,
query: tf.Tensor,
key: tf.Tensor,
key_padding_mask: Optional[tf.Tensor] = None,
layer_state: Optional[Dict[str, tf.Tensor]] = None,
attn_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""
Input shape: Time(SeqLen) x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where padding elements are indicated by 1s.
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the attention from looking forward in time
(default: None).
"""
static_kv = self.encoder_decoder_attention # value=key=encoder_hidden_states,
tgt_len, bsz, embed_dim = shape_list(query)
assert (
embed_dim == self.embed_dim
), f"query must be shaped {(tgt_len, bsz, self.embed_dim)} got {shape_list(query)}"
# reached for encoder-decoder attention because of static_kv
if layer_state is not None: # get the last k and v for reuse
saved_state = layer_state.get(self.cache_key, {})
if "prev_key" in saved_state:
# previous time steps are cached - no need to recompute key and value if they are static
if static_kv:
key = None
else:
# this branch is hit by encoder
saved_state = None
# Project query key values using weights q_proj, k_proj, v_proj
q = self.q_proj(query) * self.scaling
if static_kv and key is None: # cross-attention with cache
k = v = None
elif static_kv and key is not None: # cross-attention no prev_key found in cache
k = self.k_proj(key)
v = self.v_proj(key)
else: # self-attention
k = self.k_proj(query)
v = self.v_proj(query)
# Reshape
q = self._shape(q, tgt_len, bsz)
if k is not None:
k = self._shape(k, -1, bsz)
v = self._shape(v, -1, bsz)
if saved_state: # read from cache
k, v = self._concat_saved_state(k, v, saved_state, static_kv, bsz)
if layer_state is not None: # Write to cache every decoder call
cached_shape = (bsz, self.num_heads, -1, self.head_dim) # bsz must be first for reorder_cache
layer_state[self.cache_key] = dict(
prev_key=tf.reshape(k, cached_shape), prev_value=tf.reshape(v, cached_shape)
)
# Compute multi-headed attention
src_len = shape_list(k)[1]
attn_weights = tf.matmul(q, k, transpose_b=True) # shape (bsz * self.num_heads, tgt_len, src_len)
if attn_mask is not None:
assert attn_mask.dtype == tf.float32, f"expected dtype tf.float32 got {attn_mask.dtype}"
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attn_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
if key_padding_mask is not None: # don't attend to padding symbols
attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
if key_padding_mask.dtype == tf.bool:
key_padding_mask = tf.cast(key_padding_mask, attn_weights.dtype) * -1e9
extended_mask = tf.expand_dims(tf.expand_dims(key_padding_mask, 1), 2)
attn_weights = attn_weights + extended_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, v) # shape: (bsz * self.num_heads, tgt_len, self.head_dim)
attn_output = tf.transpose(attn_output, perm=(1, 0, 2))
attn_output = tf.reshape(attn_output, (tgt_len, bsz, embed_dim))
attn_output = self.out_proj(attn_output)
attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
return attn_output, attn_weights
def _concat_saved_state(self, k, v, saved_state, static_kv, bsz) -> Tuple[tf.Tensor]:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
prev_key = tf.reshape(saved_state["prev_key"], (bsz * self.num_heads, -1, self.head_dim))
k = prev_key if static_kv else tf.concat([prev_key, k], axis=1)
prev_value = tf.reshape(saved_state["prev_value"], (bsz * self.num_heads, -1, self.head_dim))
v = prev_value if static_kv else tf.concat([prev_value, v], axis=1)
return k, v
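# Shape walk-through of the cache read path above, with toy sizes assumed for
# illustration (bsz=2, num_heads=4, head_dim=8):
saved_state = {"prev_key": tf.zeros((2, 4, 3, 8))}              # (bsz, num_heads, seq_len, head_dim)
prev_key = tf.reshape(saved_state["prev_key"], (2 * 4, -1, 8))  # -> (8, 3, 8), the _shape() layout
k_step = tf.zeros((8, 1, 8))                                    # the current step's projected key
k = tf.concat([prev_key, k_step], axis=1)                       # self-attention: time axis grows to 4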
class TFLearnedPositionalEmbedding(TFSharedEmbeddings):
"""
This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
the forward function.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset, **kwargs):
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = offset
assert padding_idx is not None, "padding_idx cannot be None"
num_embeddings += offset
super().__init__(num_embeddings, embedding_dim, **kwargs)
def call(self, input_ids: tf.Tensor, use_cache=False):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = shape_list(input_ids)[:2]
if use_cache:
positions = tf.fill((1, 1), seq_len - 1)
else:
# starts at 0, ends at seq_len - 1
positions = tf.range(0, seq_len, delta=1, dtype=tf.int32, name="range")
return super().call(positions + self.offset) # super object is not callable for some reason
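# Sketch of the offset hack (illustrative numbers): with offset=2 the table gains two
# extra rows, so position p is stored at row p + 2.
emb = TFLearnedPositionalEmbedding(num_embeddings=1024, embedding_dim=16, padding_idx=1, offset=2)
dummy_ids = tf.zeros((1, 5), dtype=tf.int32)  # only the shape (bsz=1, seq_len=5) matters here
positions = emb(dummy_ids)                    # looks up rows 2..6 of the (1026, 16) table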
class TFSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions, embedding_dim, **kwargs):
if embedding_dim % 2 != 0:
raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
super().__init__(
num_positions,
embedding_dim,
**kwargs,
)
def build(self, input_shape):
"""
Build shared token embedding layer Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
super().build(input_shape) # Instantiates self.weight so it can be loaded
weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim)
self.set_weights([weight]) # overwrite self.weight to correct value
@staticmethod
def _init_weight(n_pos, dim):
"""
Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
the 2nd half of the vector. [dim // 2:]
"""
position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
# index 0 is all zero
position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
# convert to tensor
table = tf.convert_to_tensor(position_enc, dtype=tf.float32)
tf.stop_gradient(table)
return table
def call(self, input_ids, use_cache=False):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = shape_list(input_ids)[:2]
if use_cache:
positions = tf.fill((1, 1), seq_len - 1)
else:
# starts at 0, ends at seq_len - 1
positions = tf.range(0, seq_len, delta=1, dtype=tf.int32, name="range")
return super().call(positions)
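# Worked sketch of the table layout _init_weight produces (tiny sizes for illustration):
# sine features fill columns [:dim // 2], cosine features fill columns [dim // 2:], so
# row 0 is all zeros in the first half and all ones in the second.
table = TFSinusoidalPositionalEmbedding._init_weight(n_pos=4, dim=6)
print(table[0].numpy())  # [0. 0. 0. 1. 1. 1.]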
# Public API
@add_start_docstrings(
    "The bare BART Model outputting raw hidden-states without any specific head on top.",
    BART_START_DOCSTRING,
)
@keras_serializable
-class TFBartModel(TFPretrainedBartModel):
+class TFBartModel(TFBartPretrainedModel):
    base_model_prefix = "model"
    def __init__(self, config: BartConfig, *inputs, **kwargs):
...@@ -876,28 +998,8 @@ class TFBartModel(TFPretrainedBartModel):
        self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
        self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")
+   def get_decoder(self):
+       return self.decoder

-   def _prepare_bart_decoder_inputs(
-       self,
-       inputs,
-       decoder_input_ids=None,
-       decoder_attn_mask=None,
-       mask_dtype=None,
-   ):
-       """
-       Prepare masks that ignore padding tokens in the decoder and a causal LM mask for the decoder if none are
-       provided. This mimics the default behavior in fairseq. To override it, pass in masks.
-       """
-       pad_token_id = self.config.pad_token_id
-       if decoder_input_ids is None:
-           decoder_input_ids = self._shift_right(inputs)
-       bsz, tgt_len = shape_list(decoder_input_ids)[:2]
-       if decoder_attn_mask is None:
-           decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
-       else:
-           decoder_padding_mask = invert_mask(decoder_attn_mask)
-       causal_lm_mask = causal_attention_mask(tgt_len, tgt_len, mask_dtype)
-       return decoder_input_ids, decoder_padding_mask, causal_lm_mask
    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
...@@ -908,12 +1010,14 @@ class TFBartModel(TFPretrainedBartModel):
    )
    def call(
        self,
-       input_ids,
+       input_ids=None,
        attention_mask=None,
-       decoder_input_ids=None,  # BAD DEFAULT LEFT FOR CONSISTENT SIGNATURE
+       decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
        past_key_values=None,
+       inputs_embeds=None,
+       decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
...@@ -930,6 +1034,8 @@ class TFBartModel(TFPretrainedBartModel):
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
+           inputs_embeds=inputs_embeds,
+           decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
...@@ -938,7 +1044,7 @@ class TFBartModel(TFPretrainedBartModel):
            kwargs_call=kwargs,
        )
if inputs["decoder_input_ids"] is None: if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
inputs["use_cache"] = False inputs["use_cache"] = False
inputs["output_hidden_states"] = ( inputs["output_hidden_states"] = (
...@@ -947,19 +1053,16 @@ class TFBartModel(TFPretrainedBartModel): ...@@ -947,19 +1053,16 @@ class TFBartModel(TFPretrainedBartModel):
else self.config.output_hidden_states else self.config.output_hidden_states
) )
if not use_cache or past_key_values is None: if inputs["decoder_input_ids"] is None and inputs["input_ids"] is not None:
inputs["decoder_input_ids"], decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs( inputs["decoder_input_ids"] = shift_tokens_right(
inputs["input_ids"], inputs["input_ids"], self.config.pad_token_id, self.config.eos_token_id
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attn_mask=inputs["decoder_attention_mask"],
mask_dtype=self.shared.dtype,
) )
else:
decoder_padding_mask, causal_mask = None, None
if inputs["encoder_outputs"] is None: if inputs["encoder_outputs"] is None:
inputs["encoder_outputs"] = self.encoder( inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"], return_dict=inputs["return_dict"],
...@@ -978,11 +1081,11 @@ class TFBartModel(TFPretrainedBartModel): ...@@ -978,11 +1081,11 @@ class TFBartModel(TFPretrainedBartModel):
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
inputs["decoder_input_ids"], inputs["decoder_input_ids"],
inputs["encoder_outputs"][0], attention_mask=decoder_attention_mask,
inputs["attention_mask"], encoder_hidden_states=inputs["encoder_outputs"][0],
decoder_padding_mask, encoder_attention_mask=inputs["attention_mask"],
decoder_causal_mask=causal_mask, past_key_values=inputs["past_key_values"],
decoder_cached_states=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
...@@ -1017,7 +1120,7 @@ class TFBartModel(TFPretrainedBartModel): ...@@ -1017,7 +1120,7 @@ class TFBartModel(TFPretrainedBartModel):
"The BART Model with a language modeling head. Can be used for summarization.", "The BART Model with a language modeling head. Can be used for summarization.",
BART_START_DOCSTRING, BART_START_DOCSTRING,
) )
class TFBartForConditionalGeneration(TFPretrainedBartModel): class TFBartForConditionalGeneration(TFBartPretrainedModel):
_keys_to_ignore_on_load_unexpected = [ _keys_to_ignore_on_load_unexpected = [
r"model.encoder.embed_tokens.weight", r"model.encoder.embed_tokens.weight",
r"model.decoder.embed_tokens.weight", r"model.decoder.embed_tokens.weight",
...@@ -1032,6 +1135,9 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): ...@@ -1032,6 +1135,9 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
) )
def get_decoder(self):
return self.model.decoder
    def resize_token_embeddings(self, new_num_tokens):
        super().resize_token_embeddings(new_num_tokens=new_num_tokens)
...@@ -1041,12 +1147,11 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
        num_tokens_to_copy = min(self.final_logits_bias.shape[0], new_num_tokens)
        init_bias = tf.zeros((new_num_tokens,))
        init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
-       name = self.name + "/final_logits_bias"
        self.final_logits_bias = self.add_weight(
            shape=(1, new_num_tokens),
            initializer="zeros",
            trainable=False,
-           name=name,
+           name="final_logits_bias",
        )
        self.final_logits_bias.assign(init_bias)
...@@ -1054,12 +1159,14 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
-       input_ids,
+       input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs: Optional[TFBaseModelOutput] = None,
        past_key_values=None,
+       inputs_embeds=None,
+       decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
...@@ -1094,6 +1201,8 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
+           inputs_embeds=inputs_embeds,
+           decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
...@@ -1106,7 +1215,9 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
if inputs["labels"] is not None: if inputs["labels"] is not None:
inputs["use_cache"] = False inputs["use_cache"] = False
if inputs["decoder_input_ids"] is None: if inputs["decoder_input_ids"] is None:
inputs["decoder_input_ids"] = self._shift_right(inputs["labels"]) inputs["decoder_input_ids"] = shift_tokens_right(
inputs["labels"], self.config.pad_token_id, self.config.eos_token_id
)
outputs = self.model( outputs = self.model(
inputs["input_ids"], inputs["input_ids"],
...@@ -1115,6 +1226,8 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): ...@@ -1115,6 +1226,8 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
...@@ -1138,23 +1251,28 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): ...@@ -1138,23 +1251,28 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
) )
-   def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache=True, **kwargs) -> Dict:
+   def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
        assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
        if len(past) == 1:
-           assert isinstance(past[0], tf.Tensor)
+           assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
            encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
-           decoder_cached_states = None
+           past_key_values = None
        else:
-           assert len(past) == 2
-           encoder_outputs, decoder_cached_states = past
+           assert (
+               len(past) == 2
+           ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
+           encoder_outputs, past_key_values = past
            if isinstance(encoder_outputs, tuple):
-               assert isinstance(encoder_outputs[0], tf.Tensor)
+               assert isinstance(
+                   encoder_outputs[0], tf.Tensor
+               ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
            elif isinstance(encoder_outputs, tf.Tensor):
                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
            assert (
-               decoder_cached_states
-           ), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past"
+               past_key_values
+           ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
+           decoder_input_ids = decoder_input_ids[:, -1:]
        assert isinstance(
            encoder_outputs, TFBaseModelOutput
...@@ -1162,7 +1280,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
-           "past_key_values": decoder_cached_states,
+           "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
...@@ -1170,18 +1288,17 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
    @staticmethod
    def _reorder_cache(past, beam_idx):
-       assert len(past) == 2
-       (encoder_out, decoder_cached_states) = past
-       reordered_past = []
-       for layer_past in decoder_cached_states:
-           # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-           layer_past_new = {
-               attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
-           }
-           reordered_past.append(layer_past_new)
-       past = (encoder_out, reordered_past)
-       return past
+       if len(past) == 1:
+           return past
+
+       past_key_values = past[1]
+
+       reordered_past = ()
+       for layer_past_key_values in past_key_values:
+           reordered_past += (
+               tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values),
+           )
+       return (past[0], reordered_past)

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        if cur_len == 1 and self.config.force_bos_token_to_be_generated:
...
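# Toy round-trip through the new _reorder_cache above (the sizes and the flat per-layer
# (key, value, ...) tuple layout are assumptions for illustration only):
layer = (tf.zeros((4, 2, 5, 8)),) * 4            # e.g. self-attn key/value + cross-attn key/value
past = (tf.zeros((4, 7, 16)), (layer, layer))    # (encoder_hidden_states, per-layer caches)
beam_idx = tf.constant([1, 0, 0, 3])
encoder_out, reordered = TFBartForConditionalGeneration._reorder_cache(past, beam_idx)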
...@@ -1305,14 +1305,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
        # get decoder inputs from shifting lm labels to the right
        inputs["decoder_input_ids"] = self._shift_right(inputs["labels"])
-       # If decoding with past key value states, only the last tokens
-       # should be given as an input
-       if inputs["past_key_values"] is not None:
-           if inputs["decoder_input_ids"] is not None:
-               inputs["decoder_input_ids"] = inputs["decoder_input_ids"][:, -1:]
-           if inputs["decoder_inputs_embeds"] is not None:
-               inputs["decoder_inputs_embeds"] = inputs["decoder_inputs_embeds"][:, -1:]
        # Decode
        decoder_outputs = self.decoder(
            inputs["decoder_input_ids"],
...
...@@ -256,6 +256,15 @@ class TFBartModel:
        requires_tf(self)

+class TFBartPretrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)

TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
...@@ -207,7 +207,7 @@ class BartModelTester:
@require_torch
-class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (
        (BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
        if is_torch_available()
...
...@@ -30,7 +30,7 @@ if is_tf_available():
    import tensorflow as tf

    from transformers import TFBartForConditionalGeneration, TFBartModel
-   from transformers.models.bart.modeling_tf_bart import TFSinusoidalPositionalEmbedding
+   from transformers.models.bart.modeling_tf_bart import TFBartSinusoidalPositionalEmbedding

@require_tf
...@@ -85,6 +85,38 @@ class TFBartModelTester:
        inputs_dict = prepare_bart_inputs_dict(config, input_ids)
        return config, inputs_dict
def check_decoder_model_past_large_inputs(self, config, inputs_dict):
model = TFBartModel(config=config).get_decoder()
input_ids = inputs_dict["input_ids"]
input_ids = input_ids[:1, :]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
# create hypothetical next tokens and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
# append the new tokens to input_ids to build next_input_ids
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
output_from_no_past = model(next_input_ids)[0]
output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def prepare_bart_inputs_dict(
    config,
...@@ -114,9 +146,9 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
    def test_config(self):
        self.config_tester.run_common_tests()

-   def test_inputs_embeds(self):
-       # inputs_embeds not supported
-       pass
+   def test_decoder_model_past_large_inputs(self):
+       config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+       self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_model_common_attributes(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...@@ -285,13 +317,11 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
        model = self.xsum_1_1_model
        assert model.model.decoder.embed_tokens._layer == model.model.shared
        ARTICLE = 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.'
EXPECTED = " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
        dct = self.tok(ARTICLE, return_tensors="tf")
        generated_ids = model.generate(**dct, num_beams=4)
        result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
-       assert (
-           result
-           == " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
-       )
+       assert result == EXPECTED
    def test_xsum_1_1_batch_generation(self):
        batch = self.tok(
...@@ -325,7 +355,6 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
            truncation=True,
        )
        features = self.xsum_1_1_model.get_encoder()(**batch).last_hidden_state
-       import numpy as np
        expected = np.array([[-0.0828, -0.0251, -0.0674], [0.1277, 0.3311, -0.0255], [0.2613, -0.0840, -0.2763]])
        assert np.allclose(features[0, :3, :3].numpy(), expected, atol=1e-3)
...@@ -340,16 +369,14 @@ class TestTFSinusoidalPositionalEmbeddings(unittest.TestCase):
    ]

    def test_positional_emb_cache_logic(self):
-       input_ids = _long_tensor([[4, 10]])
-       emb1 = TFSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6)
-       no_cache = emb1(input_ids, use_cache=False)
-       yes_cache = emb1(input_ids, use_cache=True)
-       self.assertEqual((1, 1, 6), yes_cache.shape)  # extra dim to allow broadcasting, feel free to delete!
-       np.testing.assert_almost_equal(no_cache[-1].numpy(), yes_cache[0][0].numpy())
+       emb1 = TFBartSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6)
+       no_cache = emb1((4, 10), past_key_values_length=0)
+       yes_cache = emb1((4, 10), past_key_values_length=2)
+       self.assertTrue(no_cache.shape == yes_cache.shape == (10, 6))
+       self.assertListEqual(no_cache[2:].numpy().tolist(), yes_cache[:-2].numpy().tolist())

    def test_positional_emb_weights_against_marian(self):
-       emb1 = TFSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512)
+       emb1 = TFBartSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512)
        emb1.build(None)
        weights = emb1.embeddings.numpy()
        for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)):
...