Unverified Commit abc573f5 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[TF Bart] Refactor TFBart (#9029)

* reorder file

* delete unnecesarry function

* make style

* save intermediate

* fix attention masks

* correct tf bart past key values

* solve merge conflict bug

* correct tensor dims

* save intermediate tf

* change attn layer

* fix typo re-order past

* inputs_embeds

* make fix copies

* finish tests

* fix graph mode

* appyl lysandres suggestions
parent 389aba34
......@@ -717,7 +717,7 @@ if is_tf_available():
TFAutoModelForTokenClassification,
TFAutoModelWithLMHead,
)
from .models.bart import TFBartForConditionalGeneration, TFBartModel
from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
from .models.bert import (
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFBertEmbeddings,
......
......@@ -36,4 +36,4 @@ if is_torch_available():
)
if is_tf_available():
from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
......@@ -215,10 +215,10 @@ class BartAttention(nn.Module):
def forward(
self,
hidden_states,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attn_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
......@@ -274,14 +274,14 @@ class BartAttention(nn.Module):
src_len,
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
if attn_mask is not None:
assert attn_mask.size() == (
if attention_mask is not None:
assert attention_mask.size() == (
bsz,
1,
tgt_len,
src_len,
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attn_mask.size()}"
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
......@@ -335,23 +335,19 @@ class BartEncoderLayer(nn.Module):
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = BartLayerNorm(self.embed_dim)
def forward(
self, hidden_states: torch.Tensor, encoder_padding_mask: torch.Tensor, output_attentions: bool = False
):
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
"""
Args:
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (:obj:`torch.FloatTensor`): attention mask of size
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
output_attentions (:obj:`bool`): Whether the base model outputs attentions. This requires the attentions tensor to be reshaped in this function.
"""
residual = hidden_states
if self.normalize_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attn_mask=encoder_padding_mask, output_attentions=output_attentions
hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
......@@ -405,24 +401,35 @@ class BartDecoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_attn_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attn_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[torch.Tensor] = False,
):
"""
Args:
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`): Whether the base model outputs attentions. This requires the attentions tensor to be reshaped in this function.
"""
residual = hidden_states
if self.normalize_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attn_mask=attn_mask,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
......@@ -443,7 +450,7 @@ class BartDecoderLayer(nn.Module):
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attn_mask=encoder_attn_mask,
attention_mask=encoder_attention_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
......@@ -905,9 +912,9 @@ class BartDecoder(BartPretrainedModel):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attn_mask = None
combined_attention_mask = None
if input_shape[-1] > 1:
attn_mask = _make_causal_mask(
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
......@@ -928,9 +935,9 @@ class BartDecoder(BartPretrainedModel):
# never mask leading token, even if it is pad
attention_mask[:, 0] = attention_mask[:, 1]
if attention_mask is not None and attn_mask is not None:
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attn_mask = attn_mask + _expand_mask(
combined_attention_mask = combined_attention_mask + _expand_mask(
attention_mask, inputs_embeds.dtype, past_key_values_length=past_key_values_length
)
......@@ -968,9 +975,9 @@ class BartDecoder(BartPretrainedModel):
hidden_states, layer_self_attn, present_key_value, layer_cross_attn = decoder_layer(
hidden_states,
encoder_hidden_states,
encoder_attn_mask=encoder_attention_mask,
attn_mask=attn_mask,
attention_mask=combined_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
......
......@@ -1305,14 +1305,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
# get decoder inputs from shifting lm labels to the right
inputs["decoder_input_ids"] = self._shift_right(inputs["labels"])
# If decoding with past key value states, only the last tokens
# should be given as an input
if inputs["past_key_values"] is not None:
if inputs["decoder_input_ids"] is not None:
inputs["decoder_input_ids"] = inputs["decoder_input_ids"][:, -1:]
if inputs["decoder_inputs_embeds"] is not None:
inputs["decoder_inputs_embeds"] = inputs["decoder_inputs_embeds"][:, -1:]
# Decode
decoder_outputs = self.decoder(
inputs["decoder_input_ids"],
......
......@@ -256,6 +256,15 @@ class TFBartModel:
requires_tf(self)
class TFBartPretrainedModel:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -207,7 +207,7 @@ class BartModelTester:
@require_torch
class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
if is_torch_available()
......
......@@ -30,7 +30,7 @@ if is_tf_available():
import tensorflow as tf
from transformers import TFBartForConditionalGeneration, TFBartModel
from transformers.models.bart.modeling_tf_bart import TFSinusoidalPositionalEmbedding
from transformers.models.bart.modeling_tf_bart import TFBartSinusoidalPositionalEmbedding
@require_tf
......@@ -85,6 +85,38 @@ class TFBartModelTester:
inputs_dict = prepare_bart_inputs_dict(config, input_ids)
return config, inputs_dict
def check_decoder_model_past_large_inputs(self, config, inputs_dict):
model = TFBartModel(config=config).get_decoder()
input_ids = inputs_dict["input_ids"]
input_ids = input_ids[:1, :]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
output_from_no_past = model(next_input_ids)[0]
output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def prepare_bart_inputs_dict(
config,
......@@ -114,9 +146,9 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
def test_inputs_embeds(self):
# inputs_embeds not supported
pass
def test_decoder_model_past_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -285,13 +317,11 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
model = self.xsum_1_1_model
assert model.model.decoder.embed_tokens._layer == model.model.shared
ARTICLE = 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.'
EXPECTED = " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
dct = self.tok(ARTICLE, return_tensors="tf")
generated_ids = model.generate(**dct, num_beams=4)
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert (
result
== " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
)
assert result == EXPECTED
def test_xsum_1_1_batch_generation(self):
batch = self.tok(
......@@ -325,7 +355,6 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
truncation=True,
)
features = self.xsum_1_1_model.get_encoder()(**batch).last_hidden_state
import numpy as np
expected = np.array([[-0.0828, -0.0251, -0.0674], [0.1277, 0.3311, -0.0255], [0.2613, -0.0840, -0.2763]])
assert np.allclose(features[0, :3, :3].numpy(), expected, atol=1e-3)
......@@ -340,16 +369,14 @@ class TestTFSinusoidalPositionalEmbeddings(unittest.TestCase):
]
def test_positional_emb_cache_logic(self):
input_ids = _long_tensor([[4, 10]])
emb1 = TFSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6)
no_cache = emb1(input_ids, use_cache=False)
yes_cache = emb1(input_ids, use_cache=True)
self.assertEqual((1, 1, 6), yes_cache.shape) # extra dim to allow broadcasting, feel free to delete!
np.testing.assert_almost_equal(no_cache[-1].numpy(), yes_cache[0][0].numpy())
emb1 = TFBartSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6)
no_cache = emb1((4, 10), past_key_values_length=0)
yes_cache = emb1((4, 10), past_key_values_length=2)
self.assertTrue(no_cache.shape == yes_cache.shape == (10, 6))
self.assertListEqual(no_cache[2:].numpy().tolist(), yes_cache[:-2].numpy().tolist())
def test_positional_emb_weights_against_marian(self):
emb1 = TFSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512)
emb1 = TFBartSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512)
emb1.build(None)
weights = emb1.embeddings.numpy()
for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment