Commit 268d4f20 authored by thomwolf

fix position biases + better tests

parent b4fcd59a
@@ -408,7 +408,7 @@ class T5Block(nn.Module):
                                                position_bias=position_bias,
                                                head_mask=head_mask)
         hidden_states = self_attention_outputs[0]
-        outputs = self_attention_outputs[1:]
+        outputs = self_attention_outputs[1:]  # Keep self-attention outputs and relative position weights

         if not self.is_decoder:
             hidden_states = self.layer[1](hidden_states)
@@ -419,11 +419,11 @@ class T5Block(nn.Module):
                                                     position_bias=encoder_decoder_position_bias,
                                                     head_mask=head_mask)
             hidden_states = cross_attention_outputs[0]
-            outputs = cross_attention_outputs[1:] + outputs
+            outputs = outputs + cross_attention_outputs[1:]  # Keep cross-attention outputs and relative position weights
             hidden_states = self.layer[2](hidden_states)

         outputs = (hidden_states,) + outputs  # add attentions if we output them
-        return outputs
+        return outputs  # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)


 class T5PreTrainedModel(PreTrainedModel):
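The reordering in T5Block above is what makes the tuple comment on the return line accurate: cross-attention outputs are now appended after the self-attention outputs instead of prepended. A toy illustration with plain tuples (purely illustrative, not library code):

# Toy tuples standing in for the real block outputs (illustrative only).
self_attention_outputs = ("hidden", "self_attn_weights", "self_attn_bias")
cross_attention_outputs = ("hidden", "cross_attn_weights", "cross_attn_bias")

outputs = self_attention_outputs[1:]             # ("self_attn_weights", "self_attn_bias")
outputs = outputs + cross_attention_outputs[1:]  # self-attention entries stay first
assert outputs == ("self_attn_weights", "self_attn_bias",
                   "cross_attn_weights", "cross_attn_bias")
# The previous ordering, cross_attention_outputs[1:] + outputs, would have put the
# cross-attention entries first and broken the fixed indexing used in T5Stack below.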
@@ -564,14 +564,17 @@ class T5Stack(T5PreTrainedModel):
                                          encoder_attention_mask=encoder_extended_attention_mask,
                                          encoder_decoder_position_bias=encoder_decoder_position_bias,
                                          head_mask=head_mask[i])
+            # layer_outputs is a tuple with:
+            # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             hidden_states = layer_outputs[0]
             if i == 0:
                 # We share the position biases between the layers - the first layer stores them
                 position_bias = layer_outputs[2 if self.output_attentions else 1]
                 if self.is_decoder:
                     encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]

             if self.output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_attentions = all_attentions + (layer_outputs[1],)  # We keep only self-attention weights for now

         hidden_states = self.final_layer_norm(hidden_states)
         layer_output = self.dropout(hidden_states)
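This indexing is the core of the fix: relative position biases are computed once by the first block, then reused by every later block, and their position in layer_outputs shifts by one when attention weights are also returned. A minimal, self-contained sketch of that control flow using hypothetical stand-ins (not the transformers API):

# Minimal sketch of the position-bias sharing pattern (hypothetical stand-ins):
# the first block computes the bias, later blocks receive and reuse it.
def fake_block(hidden_states, position_bias=None, output_attentions=False):
    if position_bias is None:
        position_bias = "BIAS_FROM_LAYER_0"   # stands in for a learned relative bias
    outputs = (hidden_states + 1,)
    if output_attentions:
        outputs = outputs + ("ATTN_WEIGHTS",)
    return outputs + (position_bias,)

def run_stack(num_layers=3, output_attentions=True):
    hidden_states, position_bias = 0, None
    for i in range(num_layers):
        layer_outputs = fake_block(hidden_states, position_bias, output_attentions)
        hidden_states = layer_outputs[0]
        if i == 0:
            # Same indexing rule as the diff: skip the attention weights if present.
            position_bias = layer_outputs[2 if output_attentions else 1]
    return hidden_states, position_bias

assert run_stack() == (3, "BIAS_FROM_LAYER_0")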
@@ -45,9 +45,10 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
     def __init__(self,
                  parent,
                  batch_size=13,
-                 seq_length=7,
+                 encoder_seq_length=7,
+                 decoder_seq_length=9,
                  is_training=True,
-                 use_input_mask=True,
+                 use_attention_mask=True,
                  use_labels=True,
                  vocab_size=99,
                  n_positions=14,
@@ -62,9 +63,10 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
                  ):
         self.parent = parent
         self.batch_size = batch_size
-        self.seq_length = seq_length
+        self.encoder_seq_length = encoder_seq_length
+        self.decoder_seq_length = decoder_seq_length
         self.is_training = is_training
-        self.use_input_mask = use_input_mask
+        self.use_attention_mask = use_attention_mask
         self.use_labels = use_labels
         self.vocab_size = vocab_size
         self.n_positions = n_positions
@@ -78,15 +80,18 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
         self.scope = scope

     def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+        encoder_input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
+        decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

-        input_mask = None
-        if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+        encoder_attention_mask = None
+        decoder_attention_mask = None
+        if self.use_attention_mask:
+            encoder_attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
+            decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)

-        token_labels = None
+        decoder_lm_labels = None
         if self.use_labels:
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            decoder_lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

         config = T5Config(
             vocab_size_or_config_json_file=self.vocab_size,
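ids_tensor is the test suite's helper for random integer tensors; as used here it only needs to produce ids of the requested shape drawn from the given vocabulary size. A rough stand-in under that assumption (the real helper lives in the shared test utilities and may differ):

import torch

def ids_tensor(shape, vocab_size):
    # Rough stand-in: random integer ids in [0, vocab_size) with the given shape.
    return torch.randint(low=0, high=vocab_size, size=tuple(shape), dtype=torch.long)

encoder_input_ids = ids_tensor([13, 7], 99)   # batch_size=13, encoder_seq_length=7
decoder_input_ids = ids_tensor([13, 9], 99)   # batch_size=13, decoder_seq_length=9
assert encoder_input_ids.shape == (13, 7) and decoder_input_ids.shape == (13, 9)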
@@ -100,21 +105,22 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
             dropout_rate=self.dropout_rate,
             initializer_factor=self.initializer_factor)

-        return (config, input_ids, input_mask, token_labels)
+        return (config, encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask, decoder_lm_labels)

     def check_loss_output(self, result):
         self.parent.assertListEqual(
             list(result["loss"].size()),
             [])

-    def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
+    def create_and_check_t5_model(self, config, encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask, decoder_lm_labels):
         model = T5Model(config=config)
         model.eval()
-        encoder_output, decoder_output = model(encoder_input_ids=input_ids,
-                                               decoder_input_ids=input_ids,
-                                               decoder_attention_mask=input_mask)
-        encoder_output, decoder_output = model(encoder_input_ids=input_ids,
-                                               decoder_input_ids=input_ids)
+        decoder_output, encoder_output = model(encoder_input_ids=encoder_input_ids,
+                                               decoder_input_ids=decoder_input_ids,
+                                               encoder_attention_mask=encoder_attention_mask,
+                                               decoder_attention_mask=decoder_attention_mask)
+        decoder_output, encoder_output = model(encoder_input_ids=encoder_input_ids,
+                                               decoder_input_ids=decoder_input_ids)
         result = {
             "encoder_output": encoder_output,
@@ -122,17 +128,17 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
         }
         self.parent.assertListEqual(
             list(result["encoder_output"].size()),
-            [self.batch_size, self.seq_length, self.hidden_size])
+            [self.batch_size, self.encoder_seq_length, self.hidden_size])
         self.parent.assertListEqual(
             list(result["decoder_output"].size()),
-            [self.batch_size, self.seq_length, self.hidden_size])
+            [self.batch_size, self.decoder_seq_length, self.hidden_size])

-    def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
+    def create_and_check_t5_with_lm_head(self, config, encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask, decoder_lm_labels):
         model = T5WithLMHeadModel(config=config)
         model.eval()
-        outputs = model(encoder_input_ids=input_ids, decoder_input_ids=input_ids,
-                        decoder_attention_mask=input_mask, decoder_lm_labels=token_labels)
+        outputs = model(encoder_input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids,
+                        decoder_attention_mask=decoder_attention_mask, decoder_lm_labels=decoder_lm_labels)
         loss, prediction_scores = outputs[0], outputs[1]
         result = {
             "loss": loss,
@@ -140,15 +146,17 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
         }
         self.parent.assertListEqual(
             list(result["prediction_scores"].size()),
-            [self.batch_size, self.seq_length, self.vocab_size])
+            [self.batch_size, self.decoder_seq_length, self.vocab_size])
         self.check_loss_output(result)

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
-        (config, input_ids, input_mask, token_labels) = config_and_inputs
-        inputs_dict = {'encoder_input_ids': input_ids,
-                       'decoder_input_ids': input_ids,
-                       'decoder_attention_mask': input_mask}
+        (config, encoder_input_ids, decoder_input_ids, encoder_attention_mask,
+         decoder_attention_mask, decoder_lm_labels) = config_and_inputs
+        inputs_dict = {'encoder_input_ids': encoder_input_ids,
+                       'decoder_input_ids': decoder_input_ids,
+                       'decoder_attention_mask': decoder_attention_mask,
+                       'encoder_attention_mask': encoder_attention_mask}
         return config, inputs_dict

     def setUp(self):
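The point of giving the encoder and decoder different lengths (7 vs. 9) is that any mix-up between the two sequence axes now fails the shape assertions, which identical lengths could not catch. Summarizing the expected sizes from the assertions above, with a placeholder for the hidden size the diff does not show:

# Illustrative shape bookkeeping only; hidden_size is an assumed placeholder value.
batch_size, encoder_seq_length, decoder_seq_length = 13, 7, 9
hidden_size, vocab_size = 32, 99

encoder_output_shape    = [batch_size, encoder_seq_length, hidden_size]   # T5Model encoder output
decoder_output_shape    = [batch_size, decoder_seq_length, hidden_size]   # T5Model decoder output
prediction_scores_shape = [batch_size, decoder_seq_length, vocab_size]    # T5WithLMHeadModel logits
loss_shape              = []                                               # scalar LM loss

assert encoder_output_shape != decoder_output_shape  # unequal lengths expose swapped axes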