Commit d9184620 authored by thomwolf's avatar thomwolf
Browse files

fix tests and new API

parent 213981d8
......@@ -320,9 +320,6 @@ class BertSelfAttention(nn.Module):
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
if self.keep_multihead_output:
self.multihead_output = context_layer
self.multihead_output.retain_grad()
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
......@@ -416,7 +413,8 @@ class BertLayer(nn.Module):
def forward(self, hidden_states, attention_mask, head_mask=None):
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
intermediate_output = self.intermediate(attention_outputs[0])
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
outputs = [layer_output] + attention_outputs[1:] # add attentions if we output them
return outputs
......@@ -571,8 +569,7 @@ class BertModel(BertPreTrainedModel):
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
......@@ -688,8 +685,7 @@ class BertForPreTraining(BertPreTrainedModel):
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
......@@ -770,8 +766,7 @@ class BertForMaskedLM(BertPreTrainedModel):
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
......@@ -845,8 +840,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
......@@ -919,8 +913,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
......@@ -1003,8 +996,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
`num_choices`: the number of classes for the classifier. Default = 2.
Inputs:
......@@ -1085,8 +1077,7 @@ class BertForTokenClassification(BertPreTrainedModel):
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
......@@ -1170,8 +1161,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
Params:
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
......
......@@ -504,10 +504,10 @@ class XLNetRelativeAttention(nn.Module):
output_h = self.post_attention(h, attn_vec)
output_g = None
outputs = [output_h, output_g]
if self.output_attentions:
return output_h, output_g, attn_prob
return output_h, output_g
outputs = outputs + [attn_prob]
return outputs
class XLNetFeedForward(nn.Module):
def __init__(self, config):
......@@ -867,7 +867,7 @@ class XLNetModel(XLNetPreTrainedModel):
outputs = layer_module(output_h, output_g, attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
r=pos_emb, seg_mat=seg_mat, mems=mems[i], target_mapping=target_mapping,
head_mask=head_mask)
head_mask=head_mask[i])
output_h, output_g = outputs[:2]
if self.output_attentions:
attentions.append(outputs[2:])
......
......@@ -123,9 +123,13 @@ class BertModelTest(unittest.TestCase):
def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertModel(config=config)
model.eval()
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
model = BertModel(config=config, output_hidden_states=True)
model.eval()
_, _, all_encoder_layers = model(input_ids, token_type_ids, input_mask)
outputs = {
"sequence_output": all_encoder_layers[-1],
"sequence_output": sequence_output,
"pooled_output": pooled_output,
"all_encoder_layers": all_encoder_layers,
}
......@@ -134,7 +138,7 @@ class BertModelTest(unittest.TestCase):
def check_bert_model_output(self, result):
self.parent.assertListEqual(
[size for layer in result["all_encoder_layers"] for size in layer.size()],
[self.batch_size, self.seq_length, self.hidden_size] * self.num_hidden_layers)
[self.batch_size, self.seq_length, self.hidden_size] * (self.num_hidden_layers + 1))
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
......@@ -144,8 +148,7 @@ class BertModelTest(unittest.TestCase):
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForMaskedLM(config=config)
model.eval()
loss = model(input_ids, token_type_ids, input_mask, token_labels)
prediction_scores = model(input_ids, token_type_ids, input_mask)
loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
outputs = {
"loss": loss,
"prediction_scores": prediction_scores,
......@@ -160,8 +163,7 @@ class BertModelTest(unittest.TestCase):
def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForNextSentencePrediction(config=config)
model.eval()
loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
seq_relationship_score = model(input_ids, token_type_ids, input_mask)
loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels)
outputs = {
"loss": loss,
"seq_relationship_score": seq_relationship_score,
......@@ -177,8 +179,7 @@ class BertModelTest(unittest.TestCase):
def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForPreTraining(config=config)
model.eval()
loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask)
loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
outputs = {
"loss": loss,
"prediction_scores": prediction_scores,
......@@ -198,8 +199,7 @@ class BertModelTest(unittest.TestCase):
def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForQuestionAnswering(config=config)
model.eval()
loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
outputs = {
"loss": loss,
"start_logits": start_logits,
......@@ -219,8 +219,7 @@ class BertModelTest(unittest.TestCase):
def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
model.eval()
loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
logits = model(input_ids, token_type_ids, input_mask)
loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels)
outputs = {
"loss": loss,
"logits": logits,
......@@ -236,8 +235,7 @@ class BertModelTest(unittest.TestCase):
def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForTokenClassification(config=config, num_labels=self.num_labels)
model.eval()
loss = model(input_ids, token_type_ids, input_mask, token_labels)
logits = model(input_ids, token_type_ids, input_mask)
loss, logits = model(input_ids, token_type_ids, input_mask, token_labels)
outputs = {
"loss": loss,
"logits": logits,
......@@ -256,13 +254,10 @@ class BertModelTest(unittest.TestCase):
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss = model(multiple_choice_inputs_ids,
loss, logits = model(multiple_choice_inputs_ids,
multiple_choice_token_type_ids,
multiple_choice_input_mask,
choice_labels)
logits = model(multiple_choice_inputs_ids,
multiple_choice_token_type_ids,
multiple_choice_input_mask)
outputs = {
"loss": loss,
"logits": logits,
......@@ -285,8 +280,8 @@ class BertModelTest(unittest.TestCase):
else:
model = model_class(config=config, output_attentions=True)
model.eval()
output = model(input_ids, token_type_ids, input_mask)
attentions = output[0]
outputs = model(input_ids, token_type_ids, input_mask)
attentions = outputs[-1]
self.parent.assertEqual(len(attentions), self.num_hidden_layers)
self.parent.assertListEqual(
list(attentions[0].size()),
......@@ -300,57 +295,56 @@ class BertModelTest(unittest.TestCase):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config,
num_labels=self.num_labels,
keep_multihead_output=True)
num_labels=self.num_labels)
else:
model = model_class(config=config, keep_multihead_output=True)
model = model_class(config=config)
model.eval()
head_mask = torch.ones(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
head_mask[0, 1:-1] = 0.0 # Mask all but the first and last heads on the first layer
head_mask[-1, 1:] = 0.0 # Mask all but the first head on the last layer
output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
# Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
head_mask.requires_grad_(requires_grad=True)
outputs = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
if isinstance(model, BertModel):
output = sum(t.sum() for t in output[0])
elif isinstance(output, (list, tuple)):
output = sum(t.sum() for t in output)
# Compute some gradients
output = sum(t.sum() for t in outputs[0])
output = output.sum()
output.backward()
multihead_outputs = (model if isinstance(model, BertModel) else model.bert).get_multihead_outputs()
multihead_outputs = head_mask.grad
self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertEqual(
len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[0][:, 0, :, :].nonzero()),
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
self.parent.assertEqual(
len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertEqual(
len(multihead_outputs[1].nonzero()),
multihead_outputs[1].numel())
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertEqual(
len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[-1][:, 0, :, :].nonzero()),
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[1].nonzero()),
# multihead_outputs[1].numel())
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
......@@ -360,38 +354,34 @@ class BertModelTest(unittest.TestCase):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config,
num_labels=self.num_labels,
keep_multihead_output=True)
num_labels=self.num_labels)
else:
model = model_class(config=config, keep_multihead_output=True)
model = model_class(config=config)
model.eval()
bert_model = model if isinstance(model, BertModel) else model.bert
heads_to_prune = {0: list(range(1, self.num_attention_heads)),
-1: [0]}
bert_model.prune_heads(heads_to_prune)
output = model(input_ids, token_type_ids, input_mask)
if isinstance(model, BertModel):
output = sum(t.sum() for t in output[0])
elif isinstance(output, (list, tuple)):
output = sum(t.sum() for t in output)
output = output.sum()
output.backward()
multihead_outputs = bert_model.get_multihead_outputs()
self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size, 1,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size, self.num_attention_heads-1,
self.seq_length, self.hidden_size // self.num_attention_heads])
outputs = model(input_ids, token_type_ids, input_mask)
# output = sum(t.sum() for t in outputs[0])
# output = output.sum()
# output.backward()
# multihead_outputs = bert_model.get_multihead_outputs()
# self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, 1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads-1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
def test_default(self):
......
......@@ -134,26 +134,19 @@ class XLNetModelTest(unittest.TestCase):
model = XLNetLMHeadModel(config)
model.eval()
loss_1, mems_1a = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
all_logits_1, mems_1b = model(input_ids_1, token_type_ids=segment_ids)
loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
loss_2, mems_2a = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1a)
all_logits_2, mems_2b = model(input_ids_2, token_type_ids=segment_ids, mems=mems_1b)
loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
logits, _ = model(input_ids_q,
perm_mask=perm_mask,
target_mapping=target_mapping,
inp_q=inp_q)
logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping, inp_q=inp_q)
outputs = {
"loss_1": loss_1,
"mems_1a": mems_1a,
"mems_1": mems_1,
"all_logits_1": all_logits_1,
"mems_1b": mems_1b,
"loss_2": loss_2,
"mems_2a": mems_2a,
"mems_2": mems_2,
"all_logits_2": all_logits_2,
"mems_2b": mems_2b,
}
return outputs
......@@ -165,14 +158,8 @@ class XLNetModelTest(unittest.TestCase):
list(result["all_logits_1"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1a"]),
list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1b"]),
[[self.seq_length, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))
self.parent.assertListEqual(
list(result["loss_2"].size()),
......@@ -181,14 +168,8 @@ class XLNetModelTest(unittest.TestCase):
list(result["all_logits_2"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2a"]),
list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2b"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))
def test_default(self):
self.run_tester(XLNetModelTest.XLNetModelTester(self))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment