"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c48b16b8da991ba87ca82ed03e33481fc712fbfa"
Unverified commit b815edf6 authored by Patrick von Platen, committed by GitHub


[T5, Tests] Add extensive hard-coded integration tests and make sure PT and TF give equal results (#3550)

* add some t5 integration tests

* finish summarization and translation integration tests for T5 - results look good

* add tf test

* fix == vs is bug (a short sketch of the pitfall follows the commit metadata below)

* fix tf beam search error and make tf t5 tests pass
parent 8538ce90
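
The "== vs is" fix mentioned above deserves a quick illustration: "is" tests object identity, not equality, so comparing a token ID freshly produced by .item() or .numpy() against eos_token_id only works by accident for small interned integers and silently fails for realistic vocabulary IDs. A minimal sketch of the pitfall, with NumPy standing in for the framework tensors and illustrative variable names (not taken from the diff below):

import numpy as np

eos_token_id = 1000                 # hypothetical EOS id, outside CPython's small-int cache
token_id = np.array([1000])[0]      # scalar as it would come out of a beam-search step

print(token_id.item() is eos_token_id)   # False: identity check on a freshly created int object
print(token_id.item() == eos_token_id)   # True: value comparison, which the beam-search loop needs
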
@@ -119,7 +119,7 @@ class TFT5Attention(tf.keras.layers.Layer):
         if self.has_relative_attention_bias:
             self.relative_attention_bias = tf.keras.layers.Embedding(
-                self.relative_attention_num_buckets, self.n_heads, name="relative_attention_bias"
+                self.relative_attention_num_buckets, self.n_heads, name="relative_attention_bias",
             )
         self.pruned_heads = set()

@@ -178,13 +178,15 @@ class TFT5Attention(tf.keras.layers.Layer):
         memory_position = tf.range(klen)[None, :]
         relative_position = memory_position - context_position  # shape (qlen, klen)
         rp_bucket = self._relative_position_bucket(
-            relative_position, bidirectional=not self.is_decoder, num_buckets=self.relative_attention_num_buckets
+            relative_position, bidirectional=not self.is_decoder, num_buckets=self.relative_attention_num_buckets,
         )
         values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
         values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)  # shape (1, num_heads, qlen, klen)
         return values

-    def call(self, input, mask=None, kv=None, position_bias=None, cache=None, head_mask=None, training=False):
+    def call(
+        self, input, mask=None, kv=None, position_bias=None, cache=None, head_mask=None, training=False,
+    ):
         """
         Self-attention (if kv is None) or attention over source sentence (provided by kv).
         """

@@ -261,15 +263,17 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
     def __init__(self, config, has_relative_attention_bias=False, **kwargs):
         super().__init__(**kwargs)
         self.SelfAttention = TFT5Attention(
-            config, has_relative_attention_bias=has_relative_attention_bias, name="SelfAttention"
+            config, has_relative_attention_bias=has_relative_attention_bias, name="SelfAttention",
         )
         self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

-    def call(self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, training=False):
+    def call(
+        self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, training=False,
+    ):
         norm_x = self.layer_norm(hidden_states)
         attention_output = self.SelfAttention(
-            norm_x, mask=attention_mask, position_bias=position_bias, head_mask=head_mask, training=training
+            norm_x, mask=attention_mask, position_bias=position_bias, head_mask=head_mask, training=training,
         )
         y = attention_output[0]
         layer_output = hidden_states + self.dropout(y, training=training)

@@ -281,15 +285,17 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
     def __init__(self, config, has_relative_attention_bias=False, **kwargs):
         super().__init__(**kwargs)
         self.EncDecAttention = TFT5Attention(
-            config, has_relative_attention_bias=has_relative_attention_bias, name="EncDecAttention"
+            config, has_relative_attention_bias=has_relative_attention_bias, name="EncDecAttention",
         )
         self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

-    def call(self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None, training=False):
+    def call(
+        self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None, training=False,
+    ):
         norm_x = self.layer_norm(hidden_states)
         attention_output = self.EncDecAttention(
-            norm_x, mask=attention_mask, kv=kv, position_bias=position_bias, head_mask=head_mask, training=training
+            norm_x, mask=attention_mask, kv=kv, position_bias=position_bias, head_mask=head_mask, training=training,
         )
         y = attention_output[0]
         layer_output = hidden_states + self.dropout(y, training=training)

@@ -303,12 +309,12 @@ class TFT5Block(tf.keras.layers.Layer):
         self.is_decoder = config.is_decoder
         self.layer = []
         self.layer.append(
-            TFT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._0")
+            TFT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._0",)
         )
         if self.is_decoder:
             self.layer.append(
                 TFT5LayerCrossAttention(
-                    config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._1"
+                    config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._1",
                 )
             )
         self.layer.append(TFT5LayerFF(config, name="layer_._2"))

@@ -402,7 +408,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         self.num_hidden_layers = config.num_layers
         self.block = [
-            TFT5Block(config, has_relative_attention_bias=bool(i == 0), name="block_._{}".format(i))
+            TFT5Block(config, has_relative_attention_bias=bool(i == 0), name="block_._{}".format(i),)
             for i in range(config.num_layers)
         ]
         self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")

@@ -469,7 +475,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             if self.config.is_decoder:
                 seq_ids = tf.range(seq_length)
                 causal_mask = tf.less_equal(
-                    tf.tile(seq_ids[None, None, :], (batch_size, seq_length, 1)), seq_ids[None, :, None]
+                    tf.tile(seq_ids[None, None, :], (batch_size, seq_length, 1)), seq_ids[None, :, None],
                 )
                 causal_mask = tf.cast(causal_mask, dtype=tf.float32)
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]

@@ -748,7 +754,7 @@ class TFT5Model(TFT5PreTrainedModel):
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
-                input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask
+                input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask,
             )
         hidden_states = encoder_outputs[0]

@@ -852,7 +858,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
         if encoder_outputs is None:
             # Convert encoder inputs in embeddings if needed
             encoder_outputs = self.encoder(
-                input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask
+                input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask,
             )
         hidden_states = encoder_outputs[0]
...
@@ -1112,12 +1112,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]

             # next batch beam content
-            # list of (batch_size * num_beams) tuple(next hypothesis score, next token, current position in the batch)
             next_batch_beam = []

             # for each sentence
             for batch_idx in range(batch_size):
+                # if we are done with this sentence
                 if done[batch_idx]:
                     assert (
                         len(generated_hyps[batch_idx]) >= num_beams

@@ -1135,14 +1135,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                     zip(next_tokens[batch_idx], next_scores[batch_idx])
                 ):
                     # get beam and token IDs
                     beam_id = beam_token_id // vocab_size
                     token_id = beam_token_id % vocab_size
                     effective_beam_id = batch_idx * num_beams + beam_id

                     # add to generated hypotheses if end of sentence or last iteration
-                    if eos_token_id is not None and token_id.numpy() is eos_token_id:
+                    if (eos_token_id is not None) and (token_id.numpy() == eos_token_id):
                         # if beam_token does not belong to top num_beams tokens, it should not be added
                         is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                         if is_beam_token_worse_than_top_num_beams:

@@ -1158,9 +1157,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                     if len(next_sent_beam) == num_beams:
                         break

-                # if we are done with this sentence
+                # Check if were done so that we can save a pad step if all(done)
                 done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
-                    tf.reduce_max(next_scores[batch_idx]).numpy()
+                    tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len
                 )

                 # update next beam content

@@ -1178,6 +1177,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
             beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)

+            print("Scores: {}-{}".format(cur_len, beam_scores.numpy()))
+
             # re-order batch
             input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
             input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)

@@ -1185,6 +1186,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             if past is not None:
                 past = self._reorder_cache(past, beam_idx)

+            # extend attention_mask for new generated input if only decoder
             if self.config.is_encoder_decoder is False:
                 attention_mask = tf.concat(
                     [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1

@@ -1244,16 +1246,26 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # fill with hypothesis and eos_token_id if necessary
             for i, hypo in enumerate(best):
-                padding = tf.ones((sent_max_len - shape_list(hypo)[0],), dtype=tf.int32) * pad_token_id
-                decoded_hypo = tf.concat([hypo, padding], axis=0)
-
-                if sent_lengths[i] < max_length:
-                    decoded_hypo = tf.where(
-                        tf.range(max_length) == sent_lengths[i],
-                        eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
-                        decoded_hypo,
-                    )
-                decoded_list.append(decoded_hypo)
+                assert sent_lengths[i] == shape_list(hypo)[0]
+                # if sent_length is max_len do not pad
+                if sent_lengths[i] == sent_max_len:
+                    decoded_slice = hypo
+                else:
+                    # else pad to sent_max_len
+                    num_pad_tokens = sent_max_len - sent_lengths[i]
+                    padding = pad_token_id * tf.ones((num_pad_tokens,), dtype=tf.int32)
+                    decoded_slice = tf.concat([hypo, padding], axis=-1)
+
+                    # finish sentence with EOS token
+                    if sent_lengths[i] < max_length:
+                        decoded_slice = tf.where(
+                            tf.range(sent_max_len, dtype=tf.int32) == sent_lengths[i],
+                            eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
+                            decoded_slice,
+                        )
+
+                # add to list
+                decoded_list.append(decoded_slice)
             decoded = tf.stack(decoded_list)
         else:
             # none of the hypotheses have an eos_token
...
@@ -960,6 +960,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 .to(input_ids.device)
             )
             encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
+
         else:
             encoder_outputs = None
             cur_len = input_ids.shape[-1]

@@ -1284,14 +1285,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                     zip(next_tokens[batch_idx], next_scores[batch_idx])
                 ):
-                    # get beam and word IDs
+                    # get beam and token IDs
                     beam_id = beam_token_id // vocab_size
                     token_id = beam_token_id % vocab_size
                     effective_beam_id = batch_idx * num_beams + beam_id

-                    # add to generated hypotheses if end of sentence
-                    if (eos_token_id is not None) and (token_id.item() is eos_token_id):
+                    # add to generated hypotheses if end of sentence or last iteration
+                    if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                         # if beam_token does not belong to top num_beams tokens, it should not be added
                         is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                         if is_beam_token_worse_than_top_num_beams:

@@ -1300,7 +1300,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                             input_ids[effective_beam_id].clone(), beam_token_score.item(),
                         )
                     else:
-                        # add next predicted word if it is not eos_token
+                        # add next predicted token if it is not eos_token
                         next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

                     # the beam for next step is full

@@ -1330,7 +1330,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             # re-order batch
             input_ids = input_ids[beam_idx, :]
             input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)

-            # re-order internal states
             if past is not None:
                 past = self._reorder_cache(past, beam_idx)
...
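
The hard-coded integration tests added by this commit are not shown in the hunks above, but the goal stated in the title, making sure PT and TF give equal results, comes down to running generation with both frameworks from the same checkpoint and comparing the decoded text. A rough, illustrative sketch of such a cross-framework check (not the actual test code from this commit; the checkpoint name and generation settings are assumptions):

from transformers import T5Tokenizer, T5ForConditionalGeneration, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
pt_model = T5ForConditionalGeneration.from_pretrained("t5-small")
tf_model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

text = "translate English to German: How old are you?"
pt_ids = tokenizer.encode(text, return_tensors="pt")
tf_ids = tokenizer.encode(text, return_tensors="tf")

# generation settings must match exactly for the two outputs to be comparable
pt_out = pt_model.generate(pt_ids, num_beams=4, max_length=20, early_stopping=True)
tf_out = tf_model.generate(tf_ids, num_beams=4, max_length=20, early_stopping=True)

pt_text = tokenizer.decode(pt_out[0].tolist(), skip_special_tokens=True)
tf_text = tokenizer.decode(tf_out[0].numpy().tolist(), skip_special_tokens=True)
assert pt_text == tf_text, (pt_text, tf_text)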