Commit 7220d47a authored by thomwolf's avatar thomwolf
Browse files

adding head pruning and tests

parent 8415a38b
...@@ -51,12 +51,11 @@ PRETRAINED_MODEL_ARCHIVE_MAP = { ...@@ -51,12 +51,11 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
BERT_CONFIG_NAME = 'bert_config.json' BERT_CONFIG_NAME = 'bert_config.json'
TF_WEIGHTS_NAME = 'model.ckpt' TF_WEIGHTS_NAME = 'model.ckpt'
def prune_linear_layer(layer, index, dim=-1): def prune_linear_layer(layer, index, dim=0):
""" Prune a linear layer (a model parameters) to keep only entries in index. """ Prune a linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True. Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads. Used to remove heads.
""" """
dim = (dim+100) % 2
index = index.to(layer.weight.device) index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach() W = layer.weight.index_select(dim, index).clone().detach()
if layer.bias is not None: if layer.bias is not None:
...@@ -394,7 +393,7 @@ class BertAttention(nn.Module): ...@@ -394,7 +393,7 @@ class BertAttention(nn.Module):
self.output = BertSelfOutput(config) self.output = BertSelfOutput(config)
def prune_heads(self, heads): def prune_heads(self, heads):
mask = torch.ones(self.self.n_heads, self.self.d_head) mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
for head in heads: for head in heads:
mask[head] = 0 mask[head] = 0
mask = mask.view(-1).contiguous().eq(1) mask = mask.view(-1).contiguous().eq(1)
...@@ -403,7 +402,7 @@ class BertAttention(nn.Module): ...@@ -403,7 +402,7 @@ class BertAttention(nn.Module):
self.self.query = prune_linear_layer(self.self.query, index) self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index) self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index) self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=0) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params # Update hyper params
self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
......
...@@ -334,6 +334,47 @@ class BertModelTest(unittest.TestCase): ...@@ -334,6 +334,47 @@ class BertModelTest(unittest.TestCase):
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads) self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config,
num_labels=self.num_labels,
keep_multihead_output=True)
else:
model = model_class(config=config, keep_multihead_output=True)
model.eval()
bert_model = model if isinstance(model, BertModel) else model.bert
heads_to_prune = {0: list(range(1, self.num_attention_heads)),
-1: [0]}
bert_model.prune_heads(heads_to_prune)
output = model(input_ids, token_type_ids, input_mask)
if isinstance(model, BertModel):
output = sum(t.sum() for t in output[0])
elif isinstance(output, (list, tuple)):
output = sum(t.sum() for t in output)
output = output.sum()
output.backward()
multihead_outputs = bert_model.get_multihead_outputs()
self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size, 1,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size, self.num_attention_heads-1,
self.seq_length, self.hidden_size // self.num_attention_heads])
def test_default(self): def test_default(self):
self.run_tester(BertModelTest.BertModelTester(self)) self.run_tester(BertModelTest.BertModelTester(self))
...@@ -394,6 +435,7 @@ class BertModelTest(unittest.TestCase): ...@@ -394,6 +435,7 @@ class BertModelTest(unittest.TestCase):
tester.create_and_check_bert_for_attentions(*config_and_inputs) tester.create_and_check_bert_for_attentions(*config_and_inputs)
tester.create_and_check_bert_for_headmasking(*config_and_inputs) tester.create_and_check_bert_for_headmasking(*config_and_inputs)
tester.create_and_check_bert_for_head_pruning(*config_and_inputs)
@classmethod @classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None): def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment