Commit 268d4f20 authored by thomwolf

fix position biases + better tests

parent b4fcd59a
@@ -408,7 +408,7 @@ class T5Block(nn.Module):
                                                position_bias=position_bias,
                                                head_mask=head_mask)
         hidden_states = self_attention_outputs[0]
-        outputs = self_attention_outputs[1:]
+        outputs = self_attention_outputs[1:]  # Keep self-attention outputs and relative position weights
 
         if not self.is_decoder:
             hidden_states = self.layer[1](hidden_states)
@@ -419,11 +419,11 @@ class T5Block(nn.Module):
                                                     position_bias=encoder_decoder_position_bias,
                                                     head_mask=head_mask)
             hidden_states = cross_attention_outputs[0]
-            outputs = cross_attention_outputs[1:] + outputs
+            outputs = outputs + cross_attention_outputs[1:]  # Keep cross-attention outputs and relative position weights
 
         hidden_states = self.layer[2](hidden_states)
         outputs = (hidden_states,) + outputs  # add attentions if we output them
-        return outputs
+        return outputs  # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
 
 
 class T5PreTrainedModel(PreTrainedModel):
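Editor's note: the concatenation order above is the heart of the position-bias fix. The old code prepended the cross-attention extras, which pushed the self-attention position bias away from the fixed index the stack reads it from. A minimal sketch of the tuple layout, with our own placeholder names rather than real tensors (not code from the commit):

# Placeholders for the extra outputs of the two attention sub-layers (everything after their hidden states).
self_attention_outputs = ("hidden", "self_attn_weights", "self_attn_position_bias")
cross_attention_outputs = ("hidden", "cross_attn_weights", "cross_attn_position_bias")

outputs = self_attention_outputs[1:]

# Old (buggy) order: cross-attention entries jump in front of the self-attention ones.
old_order = cross_attention_outputs[1:] + outputs
assert old_order[1] != "self_attn_position_bias"

# Fixed order: self-attention entries stay first, so fixed indices keep working downstream.
new_order = outputs + cross_attention_outputs[1:]
layer_outputs = ("hidden",) + new_order
assert layer_outputs[2] == "self_attn_position_bias"  # index T5Stack reads when attentions are returned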
@@ -564,14 +564,17 @@ class T5Stack(T5PreTrainedModel):
                                          encoder_attention_mask=encoder_extended_attention_mask,
                                          encoder_decoder_position_bias=encoder_decoder_position_bias,
                                          head_mask=head_mask[i])
+            # layer_outputs is a tuple with:
+            # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             hidden_states = layer_outputs[0]
             if i == 0:
+                # We share the position biases between the layers - the first layer store them
                 position_bias = layer_outputs[2 if self.output_attentions else 1]
                 if self.is_decoder:
                     encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]
             if self.output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_attentions = all_attentions + (layer_outputs[1],)  # We keep only self-attention weights for now
 
         hidden_states = self.final_layer_norm(hidden_states)
         layer_output = self.dropout(hidden_states)
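Editor's note: a small self-contained sketch (not code from the commit) of the index arithmetic the stack uses above. The position biases sit right after the optional attention weights, so their index shifts by one depending on output_attentions:

def pick_position_biases(layer_outputs, output_attentions, is_decoder):
    # Mirrors the T5Stack logic above; only the first layer's biases are kept and reused by later layers.
    position_bias = layer_outputs[2 if output_attentions else 1]
    encoder_decoder_position_bias = None
    if is_decoder:
        encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 2]
    return position_bias, encoder_decoder_position_bias

# Decoder block with attention weights returned:
full = ("hidden", "self_attn_weights", "self_attn_bias", "cross_attn_weights", "cross_attn_bias")
assert pick_position_biases(full, True, True) == ("self_attn_bias", "cross_attn_bias")

# Decoder block without attention weights:
slim = ("hidden", "self_attn_bias", "cross_attn_bias")
assert pick_position_biases(slim, False, True) == ("self_attn_bias", "cross_attn_bias")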
@@ -45,9 +45,10 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
     def __init__(self,
                  parent,
                  batch_size=13,
-                 seq_length=7,
+                 encoder_seq_length=7,
+                 decoder_seq_length=9,
                  is_training=True,
-                 use_input_mask=True,
+                 use_attention_mask=True,
                  use_labels=True,
                  vocab_size=99,
                  n_positions=14,
@@ -62,9 +63,10 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
                  ):
         self.parent = parent
         self.batch_size = batch_size
-        self.seq_length = seq_length
+        self.encoder_seq_length = encoder_seq_length
+        self.decoder_seq_length = decoder_seq_length
         self.is_training = is_training
-        self.use_input_mask = use_input_mask
+        self.use_attention_mask = use_attention_mask
         self.use_labels = use_labels
         self.vocab_size = vocab_size
         self.n_positions = n_positions
@@ -78,15 +80,18 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
         self.scope = scope
 
     def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+        encoder_input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
+        decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
 
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+        encoder_attention_mask = None
+        decoder_attention_mask = None
+        if self.use_attention_mask:
+            encoder_attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
+            decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
 
-        token_labels = None
+        decoder_lm_labels = None
         if self.use_labels:
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            decoder_lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
 
         config = T5Config(
             vocab_size_or_config_json_file=self.vocab_size,
@@ -100,21 +105,22 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
             dropout_rate=self.dropout_rate,
             initializer_factor=self.initializer_factor)
 
-        return (config, input_ids, input_mask, token_labels)
+        return (config, encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask, decoder_lm_labels)
 
     def check_loss_output(self, result):
         self.parent.assertListEqual(
             list(result["loss"].size()),
             [])
 
-    def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
+    def create_and_check_t5_model(self, config, encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask, decoder_lm_labels):
         model = T5Model(config=config)
         model.eval()
-        encoder_output, decoder_output = model(encoder_input_ids=input_ids,
-                                               decoder_input_ids=input_ids,
-                                               decoder_attention_mask=input_mask)
-        encoder_output, decoder_output = model(encoder_input_ids=input_ids,
-                                               decoder_input_ids=input_ids)
+        decoder_output, encoder_output = model(encoder_input_ids=encoder_input_ids,
+                                               decoder_input_ids=decoder_input_ids,
+                                               encoder_attention_mask=encoder_attention_mask,
+                                               decoder_attention_mask=decoder_attention_mask)
+        decoder_output, encoder_output = model(encoder_input_ids=encoder_input_ids,
+                                               decoder_input_ids=decoder_input_ids)
 
         result = {
             "encoder_output": encoder_output,
@@ -122,17 +128,17 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
         }
         self.parent.assertListEqual(
             list(result["encoder_output"].size()),
-            [self.batch_size, self.seq_length, self.hidden_size])
+            [self.batch_size, self.encoder_seq_length, self.hidden_size])
         self.parent.assertListEqual(
             list(result["decoder_output"].size()),
-            [self.batch_size, self.seq_length, self.hidden_size])
+            [self.batch_size, self.decoder_seq_length, self.hidden_size])
 
-    def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
+    def create_and_check_t5_with_lm_head(self, config, encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask, decoder_lm_labels):
         model = T5WithLMHeadModel(config=config)
         model.eval()
-        outputs = model(encoder_input_ids=input_ids, decoder_input_ids=input_ids,
-                        decoder_attention_mask=input_mask, decoder_lm_labels=token_labels)
+        outputs = model(encoder_input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids,
+                        decoder_attention_mask=decoder_attention_mask, decoder_lm_labels=decoder_lm_labels)
         loss, prediction_scores = outputs[0], outputs[1]
         result = {
             "loss": loss,
@@ -140,15 +146,17 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
         }
         self.parent.assertListEqual(
             list(result["prediction_scores"].size()),
-            [self.batch_size, self.seq_length, self.vocab_size])
+            [self.batch_size, self.decoder_seq_length, self.vocab_size])
         self.check_loss_output(result)
 
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
-        (config, input_ids, input_mask, token_labels) = config_and_inputs
-        inputs_dict = {'encoder_input_ids': input_ids,
-                       'decoder_input_ids': input_ids,
-                       'decoder_attention_mask': input_mask}
+        (config, encoder_input_ids, decoder_input_ids, encoder_attention_mask,
+         decoder_attention_mask, decoder_lm_labels) = config_and_inputs
+        inputs_dict = {'encoder_input_ids': encoder_input_ids,
+                       'decoder_input_ids': decoder_input_ids,
+                       'decoder_attention_mask': decoder_attention_mask,
+                       'encoder_attention_mask': encoder_attention_mask}
         return config, inputs_dict
 
     def setUp(self):
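Editor's note: the reworked tester now exercises different encoder and decoder lengths (7 vs 9), which is exactly what exposes the position-bias mix-up. Below is a minimal usage sketch along the same lines; it assumes the transformers version around this commit, and the keyword arguments plus the (decoder_output, encoder_output) return order are taken from the test above and may differ in later releases:

import torch
from transformers import T5Config, T5Model

config = T5Config(vocab_size_or_config_json_file=99)  # other settings left at library defaults
model = T5Model(config)
model.eval()

batch_size, encoder_seq_length, decoder_seq_length = 13, 7, 9
encoder_input_ids = torch.randint(0, 99, (batch_size, encoder_seq_length))
decoder_input_ids = torch.randint(0, 99, (batch_size, decoder_seq_length))

with torch.no_grad():
    decoder_output, encoder_output = model(encoder_input_ids=encoder_input_ids,
                                           decoder_input_ids=decoder_input_ids)

# Encoder states follow the encoder length, decoder states the decoder length,
# matching the updated shape assertions in the test.
assert encoder_output.shape[:2] == (batch_size, encoder_seq_length)
assert decoder_output.shape[:2] == (batch_size, decoder_seq_length)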