Unverified Commit abc573f5 authored by Patrick von Platen, committed by GitHub

[TF Bart] Refactor TFBart (#9029)

* reorder file

* delete unnecessary function

* make style

* save intermediate

* fix attention masks

* correct tf bart past key values

* solve merge conflict bug

* correct tensor dims

* save intermediate tf

* change attn layer

* fix typo re-order past

* inputs_embeds

* make fix copies

* finish tests

* fix graph mode

* apply Lysandre's suggestions
parent 389aba34
@@ -717,7 +717,7 @@ if is_tf_available():
         TFAutoModelForTokenClassification,
         TFAutoModelWithLMHead,
     )
-    from .models.bart import TFBartForConditionalGeneration, TFBartModel
+    from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
     from .models.bert import (
         TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         TFBertEmbeddings,
...
@@ -36,4 +36,4 @@ if is_torch_available():
     )

 if is_tf_available():
-    from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
+    from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
@@ -215,10 +215,10 @@ class BartAttention(nn.Module):
     def forward(
         self,
-        hidden_states,
+        hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attn_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
@@ -274,14 +274,14 @@ class BartAttention(nn.Module):
             src_len,
         ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"

-        if attn_mask is not None:
-            assert attn_mask.size() == (
+        if attention_mask is not None:
+            assert attention_mask.size() == (
                 bsz,
                 1,
                 tgt_len,
                 src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attn_mask.size()}"
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
+            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = F.softmax(attn_weights, dim=-1)
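For orientation, here is a minimal standalone sketch (plain PyTorch, illustrative shapes and names, not the library code) of why an additive mask of shape `(bsz, 1, tgt_len, src_len)` filled with very large negative values suppresses the masked positions once it is added to the attention scores and pushed through the softmax, mirroring the reshape-add-reshape pattern in the hunk above:

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, src_len = 2, 4, 3, 5
attn_weights = torch.randn(bsz * num_heads, tgt_len, src_len)

# padding mask: 1 = keep, 0 = pad (last two source tokens of the second example are padding)
padding = torch.ones(bsz, src_len)
padding[1, -2:] = 0

# expand to (bsz, 1, tgt_len, src_len) and convert to additive form with large negative values
attention_mask = (1.0 - padding)[:, None, None, :].expand(bsz, 1, tgt_len, src_len) * -1e9

# same reshape -> add mask -> reshape pattern as in the hunk above
attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * num_heads, tgt_len, src_len)
attn_probs = F.softmax(attn_weights, dim=-1)

# masked positions receive (near) zero probability
assert attn_probs.view(bsz, num_heads, tgt_len, src_len)[1, :, :, -2:].max() < 1e-6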
@@ -335,23 +335,19 @@ class BartEncoderLayer(nn.Module):
         self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
         self.final_layer_norm = BartLayerNorm(self.embed_dim)

-    def forward(
-        self, hidden_states: torch.Tensor, encoder_padding_mask: torch.Tensor, output_attentions: bool = False
-    ):
+    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
         """
         Args:
             hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
-            encoder_padding_mask (:obj:`torch.FloatTensor`): attention mask of size
+            attention_mask (:obj:`torch.FloatTensor`): attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-        Returns:
-            encoded output of shape `(seq_len, batch, embed_dim)`
+            output_attentions (:obj:`bool`): Whether the base model outputs attentions. This requires the attentions tensor to be reshaped in this function.
         """
         residual = hidden_states
         if self.normalize_before:
             hidden_states = self.self_attn_layer_norm(hidden_states)
         hidden_states, attn_weights, _ = self.self_attn(
-            hidden_states=hidden_states, attn_mask=encoder_padding_mask, output_attentions=output_attentions
+            hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
         )
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
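Both the encoder and decoder layers follow the same residual pattern, with `normalize_before` switching between pre- and post-layer-norm placement. A compact standalone sketch of that control flow (illustrative PyTorch module, not the library's layer):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualSelfAttentionBlock(nn.Module):
    """Pre-LN when normalize_before=True, post-LN otherwise (simplified)."""

    def __init__(self, embed_dim, num_heads, dropout=0.1, normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = dropout
        self.normalize_before = normalize_before

    def forward(self, hidden_states):
        residual = hidden_states
        if self.normalize_before:  # pre-LN: normalize the attention input
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states, hidden_states, hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        if not self.normalize_before:  # post-LN: normalize the residual sum
            hidden_states = self.self_attn_layer_norm(hidden_states)
        return hidden_states

block = ResidualSelfAttentionBlock(embed_dim=16, num_heads=4, normalize_before=True)
print(block(torch.randn(2, 7, 16)).shape)  # torch.Size([2, 7, 16])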
@@ -405,24 +401,35 @@ class BartDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        encoder_attn_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attn_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[torch.Tensor] = False,
     ):
+        """
+        Args:
+            hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
+            attention_mask (:obj:`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
+            encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
+            output_attentions (:obj:`bool`): Whether the base model outputs attentions. This requires the attentions tensor to be reshaped in this function.
+        """
         residual = hidden_states
         if self.normalize_before:
             hidden_states = self.self_attn_layer_norm(hidden_states)

         # Self Attention
         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
         self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         # add present self-attn cache to positions 1,2 of present_key_value tuple
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_attn_past_key_value,
-            attn_mask=attn_mask,
+            attention_mask=attention_mask,
             output_attentions=output_attentions,
         )
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)

@@ -443,7 +450,7 @@ class BartDecoderLayer(nn.Module):
         hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
             hidden_states=hidden_states,
             key_value_states=encoder_hidden_states,
-            attn_mask=encoder_attn_mask,
+            attention_mask=encoder_attention_mask,
             past_key_value=cross_attn_past_key_value,
             output_attentions=output_attentions,
         )
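The 4-tuple cache layout the decoder layer relies on (self-attention key/value states first, cross-attention key/value states last) can be summarised with a small sketch; the names and shapes below are illustrative, not the library's:

import torch

bsz, num_heads, head_dim = 2, 4, 16
self_len, src_len = 7, 11  # decoder tokens generated so far / encoder sequence length

# cached projection states for one decoder layer:
#   entries 0,1 -> self-attention key/value, grow by one step per generated token
#   entries 2,3 -> cross-attention key/value, computed once from the encoder output
past_key_value = (
    torch.randn(bsz, num_heads, self_len, head_dim),  # self-attn keys
    torch.randn(bsz, num_heads, self_len, head_dim),  # self-attn values
    torch.randn(bsz, num_heads, src_len, head_dim),   # cross-attn keys
    torch.randn(bsz, num_heads, src_len, head_dim),   # cross-attn values
)

self_attn_past_key_value = past_key_value[:2]    # fed to the self-attention block
cross_attn_past_key_value = past_key_value[-2:]  # fed to the encoder-decoder attention block

assert self_attn_past_key_value[0].shape[2] == self_len
assert cross_attn_past_key_value[0].shape[2] == src_len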
@@ -905,9 +912,9 @@ class BartDecoder(BartPretrainedModel):
         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        attn_mask = None
+        combined_attention_mask = None
         if input_shape[-1] > 1:
-            attn_mask = _make_causal_mask(
+            combined_attention_mask = _make_causal_mask(
                 input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
             ).to(self.device)

@@ -928,9 +935,9 @@ class BartDecoder(BartPretrainedModel):
             # never mask leading token, even if it is pad
             attention_mask[:, 0] = attention_mask[:, 1]

-        if attention_mask is not None and attn_mask is not None:
+        if attention_mask is not None and combined_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attn_mask = attn_mask + _expand_mask(
+            combined_attention_mask = combined_attention_mask + _expand_mask(
                 attention_mask, inputs_embeds.dtype, past_key_values_length=past_key_values_length
             )

@@ -968,9 +975,9 @@ class BartDecoder(BartPretrainedModel):
             hidden_states, layer_self_attn, present_key_value, layer_cross_attn = decoder_layer(
                 hidden_states,
-                encoder_hidden_states,
-                encoder_attn_mask=encoder_attention_mask,
-                attn_mask=attn_mask,
+                attention_mask=combined_attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
             )
...
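A minimal sketch of how the decoder's causal mask and the expanded padding mask end up in one additive `combined_attention_mask`, as in the hunks above. The helpers here are simplified stand-ins for `_make_causal_mask` / `_expand_mask`, not the actual implementations:

import torch

def make_causal_mask(tgt_len, dtype, past_key_values_length=0):
    # (1, 1, tgt_len, past + tgt_len) additive mask: 0 where attending is allowed, large negative otherwise
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype)
    mask = torch.triu(mask, diagonal=1)
    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
    return mask[None, None, :, :]

def expand_mask(attention_mask, dtype, tgt_len):
    # (bsz, src_len) with 1 = keep, 0 = pad  ->  (bsz, 1, tgt_len, src_len) additive mask
    bsz, src_len = attention_mask.shape
    expanded = attention_mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    return (1.0 - expanded) * torch.finfo(dtype).min

bsz, tgt_len = 2, 4
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

combined_attention_mask = make_causal_mask(tgt_len, torch.float32) + expand_mask(
    attention_mask, torch.float32, tgt_len
)
print(combined_attention_mask.shape)  # torch.Size([2, 1, 4, 4])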
@@ -1305,14 +1305,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
             # get decoder inputs from shifting lm labels to the right
             inputs["decoder_input_ids"] = self._shift_right(inputs["labels"])

-        # If decoding with past key value states, only the last tokens
-        # should be given as an input
-        if inputs["past_key_values"] is not None:
-            if inputs["decoder_input_ids"] is not None:
-                inputs["decoder_input_ids"] = inputs["decoder_input_ids"][:, -1:]
-            if inputs["decoder_inputs_embeds"] is not None:
-                inputs["decoder_inputs_embeds"] = inputs["decoder_inputs_embeds"][:, -1:]
-
         # Decode
         decoder_outputs = self.decoder(
             inputs["decoder_input_ids"],
...
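The deleted block above encoded a common convention for cached decoding: once `past_key_values` carries the states of all previously processed positions, only the newest decoder token (or its embedding) still needs to be fed forward. A minimal sketch of that cropping step, using a hypothetical helper name rather than the library's own API:

# Hypothetical helper illustrating the convention shown in the removed lines above.
def crop_decoder_inputs(decoder_input_ids, decoder_inputs_embeds, past_key_values):
    """If a cache is present, keep only the last position of the decoder inputs."""
    if past_key_values is not None:
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        if decoder_inputs_embeds is not None:
            decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
    return decoder_input_ids, decoder_inputs_embeds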
@@ -256,6 +256,15 @@ class TFBartModel:
         requires_tf(self)

+class TFBartPretrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
+
+
 TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
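For context, this is the dummy-object pattern: when TensorFlow is not installed, placeholder classes like the one added above are exposed instead of the real ones and raise a helpful error on use. A minimal sketch of the idea, with a simplified stand-in for `requires_tf` rather than the exact library helper:

def requires_tf(obj):
    # Simplified stand-in: the real helper raises an error naming the missing backend.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    raise ImportError(f"{name} requires the TensorFlow library, but it was not found in your environment.")

class TFBartPretrainedModel:
    def __init__(self, *args, **kwargs):
        requires_tf(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_tf(cls)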
@@ -207,7 +207,7 @@ class BartModelTester:

 @require_torch
-class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (
         (BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
         if is_torch_available()
...
@@ -30,7 +30,7 @@ if is_tf_available():
     import tensorflow as tf

     from transformers import TFBartForConditionalGeneration, TFBartModel
-    from transformers.models.bart.modeling_tf_bart import TFSinusoidalPositionalEmbedding
+    from transformers.models.bart.modeling_tf_bart import TFBartSinusoidalPositionalEmbedding


 @require_tf
@@ -85,6 +85,38 @@ class TFBartModelTester:
         inputs_dict = prepare_bart_inputs_dict(config, input_ids)
         return config, inputs_dict
+    def check_decoder_model_past_large_inputs(self, config, inputs_dict):
+        model = TFBartModel(config=config).get_decoder()
+        input_ids = inputs_dict["input_ids"]
+        input_ids = input_ids[:1, :]
+        self.batch_size = 1
+
+        # first forward pass
+        outputs = model(input_ids, use_cache=True)
+
+        output, past_key_values = outputs.to_tuple()
+        past_key_values = past_key_values[1]
+
+        # create hypothetical next token and extend to next_input_ids
+        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+
+        # append to next input_ids
+        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
+
+        output_from_no_past = model(next_input_ids)[0]
+        output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
+        self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
+
+        # select random slice
+        random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
+        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
+        output_from_past_slice = output_from_past[:, :, random_slice_idx]
+
+        # test that outputs are equal for slice
+        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
 def prepare_bart_inputs_dict(
     config,

@@ -114,9 +146,9 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
     def test_config(self):
         self.config_tester.run_common_tests()

-    def test_inputs_embeds(self):
-        # inputs_embeds not supported
-        pass
+    def test_decoder_model_past_large_inputs(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+        self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)

     def test_model_common_attributes(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -285,13 +317,11 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
         model = self.xsum_1_1_model
         assert model.model.decoder.embed_tokens._layer == model.model.shared
         ARTICLE = 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.'
+        EXPECTED = " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
         dct = self.tok(ARTICLE, return_tensors="tf")
         generated_ids = model.generate(**dct, num_beams=4)
         result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
-        assert (
-            result
-            == " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
-        )
+        assert result == EXPECTED

     def test_xsum_1_1_batch_generation(self):
         batch = self.tok(
@@ -325,7 +355,6 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
             truncation=True,
         )
         features = self.xsum_1_1_model.get_encoder()(**batch).last_hidden_state
-        import numpy as np

         expected = np.array([[-0.0828, -0.0251, -0.0674], [0.1277, 0.3311, -0.0255], [0.2613, -0.0840, -0.2763]])
         assert np.allclose(features[0, :3, :3].numpy(), expected, atol=1e-3)
@@ -340,16 +369,14 @@ class TestTFSinusoidalPositionalEmbeddings(unittest.TestCase):
     ]

     def test_positional_emb_cache_logic(self):
-        input_ids = _long_tensor([[4, 10]])
-        emb1 = TFSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6)
-        no_cache = emb1(input_ids, use_cache=False)
-        yes_cache = emb1(input_ids, use_cache=True)
-        self.assertEqual((1, 1, 6), yes_cache.shape)  # extra dim to allow broadcasting, feel free to delete!
-        np.testing.assert_almost_equal(no_cache[-1].numpy(), yes_cache[0][0].numpy())
+        emb1 = TFBartSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6)
+        no_cache = emb1((4, 10), past_key_values_length=0)
+        yes_cache = emb1((4, 10), past_key_values_length=2)
+        self.assertTrue(no_cache.shape == yes_cache.shape == (10, 6))
+        self.assertListEqual(no_cache[2:].numpy().tolist(), yes_cache[:-2].numpy().tolist())

     def test_positional_emb_weights_against_marian(self):
-        emb1 = TFSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512)
+        emb1 = TFBartSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512)
         emb1.build(None)
         weights = emb1.embeddings.numpy()
         for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)):
...
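The cache-logic test above relies on sinusoidal position embeddings being a fixed function of the absolute position, so the rows for positions computed with a `past_key_values_length` offset of 2 must coincide with the corresponding rows computed without an offset. A small standalone sketch of that property (plain NumPy, illustrative helper with a simplified sin/cos layout, not `TFBartSinusoidalPositionalEmbedding` itself):

import numpy as np

def sinusoidal_embeddings(num_positions, embedding_dim):
    # Standard sin/cos position encoding: even columns sine, odd columns cosine.
    positions = np.arange(num_positions)[:, None]
    dims = np.arange(embedding_dim)[None, :]
    angles = positions / np.power(10000, 2 * (dims // 2) / embedding_dim)
    table = np.zeros((num_positions, embedding_dim))
    table[:, 0::2] = np.sin(angles[:, 0::2])
    table[:, 1::2] = np.cos(angles[:, 1::2])
    return table

table = sinusoidal_embeddings(32, 6)
seq_len, past = 10, 2
no_cache = table[0:seq_len]             # positions 0..9
yes_cache = table[past:past + seq_len]  # positions 2..11

# rows for overlapping absolute positions are identical, mirroring the assertListEqual above
np.testing.assert_allclose(no_cache[past:], yes_cache[:-past])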