Commit 87747518 authored by LysandreJik

Blocks deletion of already deleted heads. Adds the necessary integration test.

Now raises a warning when a head to be deleted has already been deleted. An integration test verifying the full pipeline (from config -> save model -> load model -> additional head pruning) has been added.
parent 719cb373
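For orientation, here is a rough sketch of the pipeline the new integration test exercises, assuming the `BertConfig`/`BertModel` classes touched in this commit; the package name (`pytorch_transformers` vs. `transformers`) and config defaults may differ between versions:

```python
# Sketch of the flow verified by the new test:
# config with pruned_heads -> build model -> save -> reload -> prune additional heads.
import os
from pytorch_transformers import BertConfig, BertModel  # `transformers` in later releases

config = BertConfig()
config.output_attentions = True
config.pruned_heads = {0: [0], 1: [1, 2]}            # heads pruned while the model is built

model = BertModel(config)
os.makedirs("pruned_model", exist_ok=True)
model.save_pretrained("pruned_model")                 # pruned_heads is saved with the config
model = BertModel.from_pretrained("pruned_model")     # pruning is re-applied on load

model.prune_heads({0: [0], 2: [1, 2]})                # head 0 of layer 0 is already gone ->
                                                      # logged warning, head is skipped
print(model.config.pruned_heads)                      # {0: [0], 1: [1, 2], 2: [1, 2]}
```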
@@ -651,6 +651,7 @@ class BertModel(BertPreTrainedModel):
        if hasattr(config, "pruned_heads"):
            pruned_heads = config.pruned_heads.copy().items()
            config.pruned_heads = {}
            for layer, heads in pruned_heads:
                if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads:
                    self.prune_heads({int(layer): list(map(int, heads))})
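The guard added here (and in the GPT-2, OpenAI GPT and XLM constructors below) only applies config-driven pruning to layers that still carry the full head count, so a layer that has already been pruned is not pruned again. A standalone illustration of that rule, using a hypothetical `current_head_counts` mapping in place of the real attention modules:

```python
# Illustration of the constructor-side guard: only layers whose attention module still has
# the full number of heads are pruned from config.pruned_heads; already-pruned layers are skipped.
# `current_head_counts` is a made-up stand-in for the per-layer modules inspected in the diff.
def heads_to_apply(pruned_heads, current_head_counts, num_attention_heads):
    to_apply = {}
    for layer, heads in pruned_heads.items():
        if current_head_counts[int(layer)] == num_attention_heads:  # layer untouched so far
            to_apply[int(layer)] = list(map(int, heads))
    return to_apply

# Layer 1 was already pruned down to 10 heads, so only layer 0 is pruned again.
print(heads_to_apply({"0": ["0"], "1": ["1", "2"]}, {0: 12, 1: 10}, 12))  # {0: [0]}
```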
@@ -455,6 +455,7 @@ class GPT2Model(GPT2PreTrainedModel):
        if hasattr(config, "pruned_heads"):
            pruned_heads = config.pruned_heads.copy().items()
            config.pruned_heads = {}
            for layer, heads in pruned_heads:
                if self.h[int(layer)].attn.n_head == config.n_head:
                    self.prune_heads({int(layer): list(map(int, heads))})
@@ -458,6 +458,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
        if hasattr(config, "pruned_heads"):
            pruned_heads = config.pruned_heads.copy().items()
            config.pruned_heads = {}
            for layer, heads in pruned_heads:
                if self.h[int(layer)].attn.n_head == config.n_head:
                    self.prune_heads({int(layer): list(map(int, heads))})
@@ -201,6 +201,10 @@ class PretrainedConfig(object):
        # Load config
        config = cls.from_json_file(resolved_config_file)

        if hasattr(config, 'pruned_heads'):
            config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}

        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
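The cast to `int` keys is needed because JSON object keys are always strings: a `pruned_heads` dict written out with the config and read back by `from_json_file` comes back with string keys. A quick round trip with the standard `json` module shows the effect:

```python
# JSON serialization turns integer dict keys into strings; the config loader undoes that.
import json

pruned_heads = {0: [0], 1: [1, 2]}
restored = json.loads(json.dumps(pruned_heads))
print(restored)                                               # {'0': [0], '1': [1, 2]}
print({int(key): value for key, value in restored.items()})   # {0: [0], 1: [1, 2]}
```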
@@ -365,15 +369,22 @@ class PreTrainedModel(nn.Module):
         """
         base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed

+        to_be_pruned = {}
         for layer, heads in heads_to_prune.items():
-            if str(layer) not in self.config.pruned_heads:
-                self.config.pruned_heads[str(layer)] = heads
+            if int(layer) not in self.config.pruned_heads:
+                self.config.pruned_heads[int(layer)] = heads
+                to_be_pruned[int(layer)] = heads
             else:
                 for head in heads:
-                    if head not in self.config.pruned_heads[str(layer)]:
-                        self.config.pruned_heads[str(layer)].append(head)
-        base_model._prune_heads(heads_to_prune)
+                    if head not in self.config.pruned_heads[int(layer)]:
+                        self.config.pruned_heads[int(layer)].append(head)
+                        to_be_pruned[int(layer)].append(head)
+                    else:
+                        logger.warning(f"Tried to remove head {head} of layer {layer} but it was already removed. "
+                                       f"The removed heads are {heads_to_prune}")
+
+        base_model._prune_heads(to_be_pruned)

     def save_pretrained(self, save_directory):
         """ Save a model and its configuration file to a directory, so that it
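In short, the new bookkeeping in `prune_heads` records every requested head in `config.pruned_heads`, forwards only the genuinely new ones to the backend `_prune_heads` call, and warns about the rest. A small self-contained sketch of that filtering (using `dict.setdefault`, so it also covers a layer that is already tracked but gains a new head):

```python
# Standalone sketch of the filtering done by the updated prune_heads: already-removed heads
# are skipped with a warning, and only new heads are passed on to the backend pruning call.
import logging

logger = logging.getLogger(__name__)

def filter_new_heads(pruned_heads, heads_to_prune):
    to_be_pruned = {}
    for layer, heads in heads_to_prune.items():
        layer = int(layer)
        for head in heads:
            if head in pruned_heads.get(layer, []):
                logger.warning("Head %d of layer %d was already removed.", head, layer)
                continue
            pruned_heads.setdefault(layer, []).append(head)   # remember it in the config
            to_be_pruned.setdefault(layer, []).append(head)   # and actually prune it
    return to_be_pruned

already = {0: [0], 1: [1, 2]}
print(filter_new_heads(already, {0: [0], 2: [1, 2]}))  # {2: [1, 2]}  (head 0 of layer 0 skipped)
print(already)                                          # {0: [0], 1: [1, 2], 2: [1, 2]}
```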
@@ -561,6 +561,7 @@ class XLMModel(XLMPreTrainedModel):
        if hasattr(config, "pruned_heads"):
            pruned_heads = config.pruned_heads.copy().items()
            config.pruned_heads = {}
            for layer, heads in pruned_heads:
                if self.attentions[int(layer)].n_heads == config.n_heads:
                    self.prune_heads({int(layer): list(map(int, heads))})
@@ -262,12 +262,9 @@ class CommonTestCases:
             outputs = model(**inputs_dict)
             attentions = outputs[-1]

-            self.assertEqual(
-                attentions[0].shape[-3], 1)
-            self.assertEqual(
-                attentions[1].shape[-3], self.model_tester.num_attention_heads)
-            self.assertEqual(
-                attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
+            self.assertEqual(attentions[0].shape[-3], 1)
+            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
+            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)

             shutil.rmtree(directory)
@@ -293,12 +290,67 @@ class CommonTestCases:
             outputs = model(**inputs_dict)
             attentions = outputs[-1]

-            self.assertEqual(
-                attentions[0].shape[-3], 1)
-            self.assertEqual(
-                attentions[1].shape[-3], self.model_tester.num_attention_heads)
-            self.assertEqual(
-                attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
+            self.assertEqual(attentions[0].shape[-3], 1)
+            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
+            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
+
+    def test_head_pruning_integration(self):
+        if not self.test_pruning:
+            return
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            if "head_mask" in inputs_dict:
+                del inputs_dict["head_mask"]
+
+            config.output_attentions = True
+            config.output_hidden_states = False
+
+            heads_to_prune = {0: [0], 1: [1, 2]}
+            config.pruned_heads = heads_to_prune
+
+            model = model_class(config=config)
+            model.eval()
+
+            outputs = model(**inputs_dict)
+            attentions = outputs[-1]
+
+            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
+            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
+            self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
+            self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
+
+            directory = "pruned_model"
+            if not os.path.exists(directory):
+                os.makedirs(directory)
+            model.save_pretrained(directory)
+            model = model_class.from_pretrained(directory)
+            shutil.rmtree(directory)
+
+            outputs = model(**inputs_dict)
+            attentions = outputs[-1]
+
+            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
+            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
+            self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
+            self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
+
+            heads_to_prune = {0: [0], 2: [1, 2]}
+            model.prune_heads(heads_to_prune)
+
+            outputs = model(**inputs_dict)
+            attentions = outputs[-1]
+
+            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
+            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
+            self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads - 2)
+            self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
+
+            self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})
+
     def test_hidden_states_output(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()