Commit 42e00cf9 authored by Lysandre, committed by LysandreJik

Pruning saved to configuration first try

parent d7a4c325
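The change below makes head pruning survive serialization: pruned heads are recorded in the model configuration, so a saved-and-reloaded model re-applies the same pruning. A minimal sketch of that round trip, assuming a pytorch_transformers install with this patch applied (the model name and head indices are illustrative, not from this commit):

    from pytorch_transformers import BertModel

    model = BertModel.from_pretrained("bert-base-uncased")
    model.prune_heads({0: [1, 2], 11: [0]})  # recorded in model.config.pruned_heads
    model.save_pretrained("./pruned-bert")   # config.json now carries "pruned_heads"

    # BertModel.__init__ reads config.pruned_heads and prunes again,
    # so the reloaded model has the same heads removed.
    reloaded = BertModel.from_pretrained("./pruned-bert")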
@@ -649,6 +649,12 @@ class BertModel(BertPreTrainedModel):
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        if hasattr(config, "pruned_heads"):
            pruned_heads = config.pruned_heads.copy().items()
            for layer, heads in pruned_heads:
                if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads:
                    self.prune_heads({int(layer): list(map(int, heads))})

        self.apply(self.init_weights)

    def _resize_token_embeddings(self, new_num_tokens):
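The int(layer) and map(int, heads) casts above exist because the configuration round-trips through config.json, and JSON object keys are always strings. A one-liner illustrating the effect (standalone, not library code):

    import json

    # Integer layer keys come back as strings after a JSON round trip,
    # hence the casts in BertModel.__init__ above.
    restored = json.loads(json.dumps({0: [1, 2], 11: [0]}))
    assert restored == {"0": [1, 2], "11": [0]}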
@@ -104,6 +104,7 @@ class PretrainedConfig(object):
        self.output_attentions = kwargs.pop('output_attentions', False)
        self.output_hidden_states = kwargs.pop('output_hidden_states', False)
        self.torchscript = kwargs.pop('torchscript', False)
        self.pruned_heads = kwargs.pop('pruned_heads', {})

    def save_pretrained(self, save_directory):
        """ Save a configuration object to the directory `save_directory`, so that it
@@ -363,6 +364,15 @@ class PreTrainedModel(nn.Module):
            heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed

        for layer, heads in heads_to_prune.items():
            if str(layer) not in self.config.pruned_heads:
                self.config.pruned_heads[str(layer)] = heads
            else:
                for head in heads:
                    if head not in self.config.pruned_heads[str(layer)]:
                        self.config.pruned_heads[str(layer)].append(head)

        base_model._prune_heads(heads_to_prune)

    def save_pretrained(self, save_directory):
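The bookkeeping added to prune_heads merges rather than overwrites, so repeated calls accumulate per layer and already-recorded heads are not duplicated. A standalone mock of that merge logic (not the library API, just the same loop extracted for illustration):

    pruned_heads = {}

    def record(heads_to_prune):
        # Mirror of the merge above: string layer keys, deduplicated head lists.
        for layer, heads in heads_to_prune.items():
            if str(layer) not in pruned_heads:
                pruned_heads[str(layer)] = heads
            else:
                for head in heads:
                    if head not in pruned_heads[str(layer)]:
                        pruned_heads[str(layer)].append(head)

    record({0: [1, 2]})
    record({0: [2, 3], 5: [0]})
    assert pruned_heads == {"0": [1, 2, 3], "5": [0]}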
@@ -219,6 +219,7 @@ class CommonTestCases:
            del inputs_dict["head_mask"]

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.output_attentions = True
            config.output_hidden_states = False
            model = model_class(config=config)
@@ -237,6 +238,61 @@ class CommonTestCases:
            self.assertEqual(
                attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
    def test_head_pruning_save_load_from_pretrained(self):
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.output_attentions = True
            config.output_hidden_states = False
            model = model_class(config=config)
            model.eval()
            heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
                              -1: [0]}
            model.prune_heads(heads_to_prune)
            directory = "pruned_model"
            if not os.path.exists(directory):
                os.makedirs(directory)
            model.save_pretrained(directory)
            model = model_class.from_pretrained(directory)

            outputs = model(**inputs_dict)
            attentions = outputs[-1]
            self.assertEqual(
                attentions[0].shape[-3], 1)
            self.assertEqual(
                attentions[1].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(
                attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)

            shutil.rmtree(directory)
    def test_head_pruning_save_load_from_config_init(self):
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.output_attentions = True
            config.output_hidden_states = False

            heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
                              -1: [0]}
            config.pruned_heads = heads_to_prune

            model = model_class(config=config)
            model.eval()

            outputs = model(**inputs_dict)
            attentions = outputs[-1]

            self.assertEqual(
                attentions[0].shape[-3], 1)
            self.assertEqual(
                attentions[1].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(
                attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
    def test_hidden_states_output(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()