Commit d9184620 authored by thomwolf

fix tests and new API

parent 213981d8
@@ -320,9 +320,6 @@ class BertSelfAttention(nn.Module):
         attention_probs = attention_probs * head_mask

         context_layer = torch.matmul(attention_probs, value_layer)
-        if self.keep_multihead_output:
-            self.multihead_output = context_layer
-            self.multihead_output.retain_grad()
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -416,7 +413,8 @@ class BertLayer(nn.Module):
     def forward(self, hidden_states, attention_mask, head_mask=None):
         attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
-        intermediate_output = self.intermediate(attention_outputs[0])
+        attention_output = attention_outputs[0]
+        intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
         outputs = [layer_output] + attention_outputs[1:]  # add attentions if we output them
         return outputs
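
Note: the hunk above is the core of the return convention this commit moves to: every sub-module returns a list whose first element is the main tensor and whose tail carries the optional extras (attention weights, hidden states). A minimal sketch of the same logic as a standalone function; the callables here are parameters for illustration, not library code:

```python
def bert_layer_forward(attention, intermediate, output, hidden_states, attention_mask, head_mask=None):
    # attention() is assumed to return [attention_output, (optional) attention_probs]
    attention_outputs = attention(hidden_states, attention_mask, head_mask)
    attention_output = attention_outputs[0]                # the main tensor always comes first
    layer_output = output(intermediate(attention_output), attention_output)
    return [layer_output] + attention_outputs[1:]          # pass the extras through unchanged
```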
@@ -571,8 +569,7 @@ class BertModel(BertPreTrainedModel):
     Params:
         `config`: a BertConfig class instance with the configuration to build a new model
         `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-            This can be used to compute head importance metrics. Default: False
+        `output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False

     Inputs:
         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
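
Note: as the updated tests below show, the new `output_hidden_states` flag appends the per-layer hidden states to the usual return value instead of replacing it. A hedged usage sketch; the import path and the small BertConfig values are assumptions for illustration, while the call pattern and unpacking order mirror the tests in this commit:

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertModel  # import path may differ at this commit

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=5, num_attention_heads=4, intermediate_size=37)
input_ids = torch.randint(0, 99, (2, 7))
token_type_ids = torch.zeros_like(input_ids)
input_mask = torch.ones_like(input_ids)

model = BertModel(config=config)
model.eval()
sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)  # default: two outputs

model = BertModel(config=config, output_hidden_states=True)
model.eval()
sequence_output, pooled_output, hidden_states = model(input_ids, token_type_ids, input_mask)
# the updated test below expects num_hidden_layers + 1 entries in hidden_states
```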
@@ -688,8 +685,7 @@ class BertForPreTraining(BertPreTrainedModel):
     Params:
         `config`: a BertConfig class instance with the configuration to build a new model
         `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-            This can be used to compute head importance metrics. Default: False
+        `output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False

     Inputs:
         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@@ -770,8 +766,7 @@ class BertForMaskedLM(BertPreTrainedModel):
     Params:
         `config`: a BertConfig class instance with the configuration to build a new model
         `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-            This can be used to compute head importance metrics. Default: False
+        `output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False

     Inputs:
         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@@ -845,8 +840,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
     Params:
         `config`: a BertConfig class instance with the configuration to build a new model
         `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-            This can be used to compute head importance metrics. Default: False
+        `output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False

     Inputs:
         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@@ -919,8 +913,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
     Params:
         `config`: a BertConfig class instance with the configuration to build a new model
         `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-            This can be used to compute head importance metrics. Default: False
+        `output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
         `num_labels`: the number of classes for the classifier. Default = 2.

     Inputs:
@@ -1003,8 +996,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
     Params:
         `config`: a BertConfig class instance with the configuration to build a new model
         `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-            This can be used to compute head importance metrics. Default: False
+        `output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
         `num_choices`: the number of classes for the classifier. Default = 2.

     Inputs:
@@ -1085,8 +1077,7 @@ class BertForTokenClassification(BertPreTrainedModel):
     Params:
         `config`: a BertConfig class instance with the configuration to build a new model
         `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-            This can be used to compute head importance metrics. Default: False
+        `output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
        `num_labels`: the number of classes for the classifier. Default = 2.

     Inputs:
@@ -1170,8 +1161,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
     Params:
         `config`: a BertConfig class instance with the configuration to build a new model
         `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-            This can be used to compute head importance metrics. Default: False
+        `output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False

     Inputs:
         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
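
Note: the same constructor flag applies to `output_attentions` across all of the heads above. Per the updated attention test further down, the attention weights are appended as the last element of the returned outputs, one tensor per layer. A hedged sketch; import path and config values are illustrative assumptions, the indexing follows the test:

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertModel  # import path may differ at this commit

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4, intermediate_size=37)
model = BertModel(config=config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 99, (2, 7))
token_type_ids = torch.zeros_like(input_ids)
input_mask = torch.ones_like(input_ids)

outputs = model(input_ids, token_type_ids, input_mask)
attentions = outputs[-1]                       # attention weights are appended last
assert len(attentions) == config.num_hidden_layers
```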
...
@@ -504,10 +504,10 @@ class XLNetRelativeAttention(nn.Module):
             output_h = self.post_attention(h, attn_vec)
             output_g = None

+        outputs = [output_h, output_g]
         if self.output_attentions:
-            return output_h, output_g, attn_prob
-        return output_h, output_g
+            outputs = outputs + [attn_prob]
+        return outputs


 class XLNetFeedForward(nn.Module):
     def __init__(self, config):
@@ -867,7 +867,7 @@ class XLNetModel(XLNetPreTrainedModel):
             outputs = layer_module(output_h, output_g, attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
                                    r=pos_emb, seg_mat=seg_mat, mems=mems[i], target_mapping=target_mapping,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask[i])
             output_h, output_g = outputs[:2]
             if self.output_attentions:
                 attentions.append(outputs[2:])
...
@@ -123,9 +123,13 @@ class BertModelTest(unittest.TestCase):
         def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertModel(config=config)
             model.eval()
-            all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
+            sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
+            model = BertModel(config=config, output_hidden_states=True)
+            model.eval()
+            _, _, all_encoder_layers = model(input_ids, token_type_ids, input_mask)
             outputs = {
-                "sequence_output": all_encoder_layers[-1],
+                "sequence_output": sequence_output,
                 "pooled_output": pooled_output,
                 "all_encoder_layers": all_encoder_layers,
             }
@@ -134,7 +138,7 @@ class BertModelTest(unittest.TestCase):
         def check_bert_model_output(self, result):
             self.parent.assertListEqual(
                 [size for layer in result["all_encoder_layers"] for size in layer.size()],
-                [self.batch_size, self.seq_length, self.hidden_size] * self.num_hidden_layers)
+                [self.batch_size, self.seq_length, self.hidden_size] * (self.num_hidden_layers + 1))
             self.parent.assertListEqual(
                 list(result["sequence_output"].size()),
                 [self.batch_size, self.seq_length, self.hidden_size])
@@ -144,8 +148,7 @@ class BertModelTest(unittest.TestCase):
         def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForMaskedLM(config=config)
             model.eval()
-            loss = model(input_ids, token_type_ids, input_mask, token_labels)
-            prediction_scores = model(input_ids, token_type_ids, input_mask)
+            loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
             outputs = {
                 "loss": loss,
                 "prediction_scores": prediction_scores,
@@ -160,8 +163,7 @@ class BertModelTest(unittest.TestCase):
         def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForNextSentencePrediction(config=config)
             model.eval()
-            loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
-            seq_relationship_score = model(input_ids, token_type_ids, input_mask)
+            loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels)
             outputs = {
                 "loss": loss,
                 "seq_relationship_score": seq_relationship_score,
@@ -177,8 +179,7 @@ class BertModelTest(unittest.TestCase):
         def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForPreTraining(config=config)
             model.eval()
-            loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
-            prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask)
+            loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
             outputs = {
                 "loss": loss,
                 "prediction_scores": prediction_scores,
@@ -198,8 +199,7 @@ class BertModelTest(unittest.TestCase):
         def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForQuestionAnswering(config=config)
             model.eval()
-            loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
-            start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
+            loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
             outputs = {
                 "loss": loss,
                 "start_logits": start_logits,
@@ -219,8 +219,7 @@ class BertModelTest(unittest.TestCase):
         def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
             model.eval()
-            loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
-            logits = model(input_ids, token_type_ids, input_mask)
+            loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels)
             outputs = {
                 "loss": loss,
                 "logits": logits,
@@ -236,8 +235,7 @@ class BertModelTest(unittest.TestCase):
         def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForTokenClassification(config=config, num_labels=self.num_labels)
             model.eval()
-            loss = model(input_ids, token_type_ids, input_mask, token_labels)
-            logits = model(input_ids, token_type_ids, input_mask)
+            loss, logits = model(input_ids, token_type_ids, input_mask, token_labels)
             outputs = {
                 "loss": loss,
                 "logits": logits,
@@ -256,13 +254,10 @@ class BertModelTest(unittest.TestCase):
             multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
             multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
             multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
-            loss = model(multiple_choice_inputs_ids,
+            loss, logits = model(multiple_choice_inputs_ids,
                          multiple_choice_token_type_ids,
                          multiple_choice_input_mask,
                          choice_labels)
-            logits = model(multiple_choice_inputs_ids,
-                           multiple_choice_token_type_ids,
-                           multiple_choice_input_mask)
             outputs = {
                 "loss": loss,
                 "logits": logits,
@@ -285,8 +280,8 @@ class BertModelTest(unittest.TestCase):
             else:
                 model = model_class(config=config, output_attentions=True)
             model.eval()
-            output = model(input_ids, token_type_ids, input_mask)
-            attentions = output[0]
+            outputs = model(input_ids, token_type_ids, input_mask)
+            attentions = outputs[-1]
             self.parent.assertEqual(len(attentions), self.num_hidden_layers)
             self.parent.assertListEqual(
                 list(attentions[0].size()),
@@ -300,57 +295,56 @@ class BertModelTest(unittest.TestCase):
             if model_class in [BertForSequenceClassification,
                                BertForTokenClassification]:
                 model = model_class(config=config,
-                                    num_labels=self.num_labels,
-                                    keep_multihead_output=True)
+                                    num_labels=self.num_labels)
             else:
-                model = model_class(config=config, keep_multihead_output=True)
+                model = model_class(config=config)
             model.eval()
             head_mask = torch.ones(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
             head_mask[0, 1:-1] = 0.0  # Mask all but the first and last heads on the first layer
             head_mask[-1, 1:] = 0.0  # Mask all but the first head on the last layer
-            output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
+            # Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
+            head_mask.requires_grad_(requires_grad=True)
+            outputs = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)

-            if isinstance(model, BertModel):
-                output = sum(t.sum() for t in output[0])
-            elif isinstance(output, (list, tuple)):
-                output = sum(t.sum() for t in output)
+            # Compute some gradients
+            output = sum(t.sum() for t in outputs[0])
             output = output.sum()
             output.backward()
-            multihead_outputs = (model if isinstance(model, BertModel) else model.bert).get_multihead_outputs()
+            multihead_outputs = head_mask.grad

             self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
-            self.parent.assertListEqual(
-                list(multihead_outputs[0].size()),
-                [self.batch_size, self.num_attention_heads,
-                 self.seq_length, self.hidden_size // self.num_attention_heads])
-            self.parent.assertEqual(
-                len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
-                0)
-            self.parent.assertEqual(
-                len(multihead_outputs[0][:, 0, :, :].nonzero()),
-                self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-            self.parent.assertEqual(
-                len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
-                self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-            self.parent.assertListEqual(
-                list(multihead_outputs[1].size()),
-                [self.batch_size, self.num_attention_heads,
-                 self.seq_length, self.hidden_size // self.num_attention_heads])
-            self.parent.assertEqual(
-                len(multihead_outputs[1].nonzero()),
-                multihead_outputs[1].numel())
-            self.parent.assertListEqual(
-                list(multihead_outputs[-1].size()),
-                [self.batch_size, self.num_attention_heads,
-                 self.seq_length, self.hidden_size // self.num_attention_heads])
-            self.parent.assertEqual(
-                len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
-                0)
-            self.parent.assertEqual(
-                len(multihead_outputs[-1][:, 0, :, :].nonzero()),
-                self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+            # self.parent.assertListEqual(
+            #     list(multihead_outputs[0].size()),
+            #     [self.batch_size, self.num_attention_heads,
+            #      self.seq_length, self.hidden_size // self.num_attention_heads])
+            # self.parent.assertEqual(
+            #     len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
+            #     0)
+            # self.parent.assertEqual(
+            #     len(multihead_outputs[0][:, 0, :, :].nonzero()),
+            #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+            # self.parent.assertEqual(
+            #     len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
+            #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+            # self.parent.assertListEqual(
+            #     list(multihead_outputs[1].size()),
+            #     [self.batch_size, self.num_attention_heads,
+            #      self.seq_length, self.hidden_size // self.num_attention_heads])
+            # self.parent.assertEqual(
+            #     len(multihead_outputs[1].nonzero()),
+            #     multihead_outputs[1].numel())
+            # self.parent.assertListEqual(
+            #     list(multihead_outputs[-1].size()),
+            #     [self.batch_size, self.num_attention_heads,
+            #      self.seq_length, self.hidden_size // self.num_attention_heads])
+            # self.parent.assertEqual(
+            #     len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
+            #     0)
+            # self.parent.assertEqual(
+            #     len(multihead_outputs[-1][:, 0, :, :].nonzero()),
+            #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)

         def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
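
Note: with `keep_multihead_output` removed, the test above estimates head importance directly from the gradient of the `head_mask` tensor, which has one row per layer and one column per head. A condensed sketch of that recipe; the model, config and input preparation are left to the caller and anything not shown in the test above is an assumption:

```python
import torch

def head_importance_from_grads(model, input_ids, token_type_ids, input_mask,
                               num_hidden_layers, num_attention_heads):
    """Return |d loss_proxy / d head_mask| as a [num_hidden_layers, num_attention_heads] tensor."""
    model.eval()
    head_mask = torch.ones(num_hidden_layers, num_attention_heads).to(input_ids.device)
    # Enable grad only once the tensor is fully prepared, otherwise autograd complains
    # that a leaf variable has been moved into the graph interior (see the comment above).
    head_mask.requires_grad_(requires_grad=True)
    outputs = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
    loss_proxy = sum(t.sum() for t in outputs[0]).sum()  # reduce the first output to a scalar
    loss_proxy.backward()
    return head_mask.grad.abs()
```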
@@ -360,38 +354,34 @@ class BertModelTest(unittest.TestCase):
             if model_class in [BertForSequenceClassification,
                                BertForTokenClassification]:
                 model = model_class(config=config,
-                                    num_labels=self.num_labels,
-                                    keep_multihead_output=True)
+                                    num_labels=self.num_labels)
             else:
-                model = model_class(config=config, keep_multihead_output=True)
+                model = model_class(config=config)
             model.eval()
             bert_model = model if isinstance(model, BertModel) else model.bert
             heads_to_prune = {0: list(range(1, self.num_attention_heads)),
                               -1: [0]}
             bert_model.prune_heads(heads_to_prune)
-            output = model(input_ids, token_type_ids, input_mask)
+            outputs = model(input_ids, token_type_ids, input_mask)

-            if isinstance(model, BertModel):
-                output = sum(t.sum() for t in output[0])
-            elif isinstance(output, (list, tuple)):
-                output = sum(t.sum() for t in output)
-            output = output.sum()
-            output.backward()
-            multihead_outputs = bert_model.get_multihead_outputs()
+            # output = sum(t.sum() for t in outputs[0])
+            # output = output.sum()
+            # output.backward()
+            # multihead_outputs = bert_model.get_multihead_outputs()

-            self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
-            self.parent.assertListEqual(
-                list(multihead_outputs[0].size()),
-                [self.batch_size, 1,
-                 self.seq_length, self.hidden_size // self.num_attention_heads])
-            self.parent.assertListEqual(
-                list(multihead_outputs[1].size()),
-                [self.batch_size, self.num_attention_heads,
-                 self.seq_length, self.hidden_size // self.num_attention_heads])
-            self.parent.assertListEqual(
-                list(multihead_outputs[-1].size()),
-                [self.batch_size, self.num_attention_heads-1,
-                 self.seq_length, self.hidden_size // self.num_attention_heads])
+            # self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
+            # self.parent.assertListEqual(
+            #     list(multihead_outputs[0].size()),
+            #     [self.batch_size, 1,
+            #      self.seq_length, self.hidden_size // self.num_attention_heads])
+            # self.parent.assertListEqual(
+            #     list(multihead_outputs[1].size()),
+            #     [self.batch_size, self.num_attention_heads,
+            #      self.seq_length, self.hidden_size // self.num_attention_heads])
+            # self.parent.assertListEqual(
+            #     list(multihead_outputs[-1].size()),
+            #     [self.batch_size, self.num_attention_heads-1,
+            #      self.seq_length, self.hidden_size // self.num_attention_heads])

         def test_default(self):
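
Note: the pruning entry point itself is untouched by this commit; `prune_heads` still takes a `{layer_index: [head_indices]}` dict on the base `BertModel` (reachable as `model.bert` on the task-specific heads), and the test above also relies on a negative layer index for the last layer. A minimal sketch mirroring the test's pruning plan:

```python
def prune_first_and_last_layer_heads(bert_model, num_attention_heads):
    """`bert_model` is the base BertModel (use `model.bert` for the task-specific heads)."""
    heads_to_prune = {0: list(range(1, num_attention_heads)),  # keep only head 0 in the first layer
                      -1: [0]}                                 # drop head 0 in the last layer
    bert_model.prune_heads(heads_to_prune)
```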
...
@@ -134,26 +134,19 @@ class XLNetModelTest(unittest.TestCase):
             model = XLNetLMHeadModel(config)
             model.eval()

-            loss_1, mems_1a = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
-            all_logits_1, mems_1b = model(input_ids_1, token_type_ids=segment_ids)
+            loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)

-            loss_2, mems_2a = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1a)
-            all_logits_2, mems_2b = model(input_ids_2, token_type_ids=segment_ids, mems=mems_1b)
+            loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)

-            logits, _ = model(input_ids_q,
-                              perm_mask=perm_mask,
-                              target_mapping=target_mapping,
-                              inp_q=inp_q)
+            logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping, inp_q=inp_q)

             outputs = {
                 "loss_1": loss_1,
-                "mems_1a": mems_1a,
+                "mems_1": mems_1,
                 "all_logits_1": all_logits_1,
-                "mems_1b": mems_1b,
                 "loss_2": loss_2,
-                "mems_2a": mems_2a,
+                "mems_2": mems_2,
                 "all_logits_2": all_logits_2,
-                "mems_2b": mems_2b,
             }
             return outputs
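
Note: `XLNetLMHeadModel` follows the same consolidation: with `labels`, one forward pass returns `(loss, logits, mems)`, and the returned `mems` can be fed back in for the next segment, where the old API used separate calls (and separate memories) for the loss and the logits. A hedged sketch of the two-segment pattern; the model and the input tensors are assumed to be prepared as in the tester above:

```python
import torch

def two_segment_lm_step(model, input_ids_1, input_ids_2, segment_ids, lm_labels):
    model.eval()
    loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
    # Reuse the memory from the first segment when scoring the second one.
    loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids,
                                         labels=lm_labels, mems=mems_1)
    return (loss_1, all_logits_1, mems_1), (loss_2, all_logits_2, mems_2)
```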
@@ -165,14 +158,8 @@ class XLNetModelTest(unittest.TestCase):
                 list(result["all_logits_1"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_1a"]),
+                list(list(mem.size()) for mem in result["mems_1"]),
                 [[self.seq_length, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_1b"]),
-                [[self.seq_length, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))

             self.parent.assertListEqual(
                 list(result["loss_2"].size()),
@@ -181,14 +168,8 @@ class XLNetModelTest(unittest.TestCase):
                 list(result["all_logits_2"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_2a"]),
+                list(list(mem.size()) for mem in result["mems_2"]),
                 [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_2b"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))

         def test_default(self):
             self.run_tester(XLNetModelTest.XLNetModelTester(self))
...