Commit 286d5bb6 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Use a random temp dir for writing pruned models in tests.

parent 478e456e
...@@ -353,11 +353,10 @@ class CommonTestCases: ...@@ -353,11 +353,10 @@ class CommonTestCases:
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0]} -1: [0]}
model.prune_heads(heads_to_prune) model.prune_heads(heads_to_prune)
directory = "pruned_model"
if not os.path.exists(directory): with TemporaryDirectory() as temp_dir_name:
os.makedirs(directory) model.save_pretrained(temp_dir_name)
model.save_pretrained(directory) model = model_class.from_pretrained(temp_dir_name)
model = model_class.from_pretrained(directory)
model.to(torch_device) model.to(torch_device)
with torch.no_grad(): with torch.no_grad():
...@@ -367,7 +366,6 @@ class CommonTestCases: ...@@ -367,7 +366,6 @@ class CommonTestCases:
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads) self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
shutil.rmtree(directory)
def test_head_pruning_save_load_from_config_init(self): def test_head_pruning_save_load_from_config_init(self):
if not self.test_pruning: if not self.test_pruning:
...@@ -427,14 +425,10 @@ class CommonTestCases: ...@@ -427,14 +425,10 @@ class CommonTestCases:
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads) self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads) self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
directory = "pruned_model" with TemporaryDirectory() as temp_dir_name:
model.save_pretrained(temp_dir_name)
if not os.path.exists(directory): model = model_class.from_pretrained(temp_dir_name)
os.makedirs(directory)
model.save_pretrained(directory)
model = model_class.from_pretrained(directory)
model.to(torch_device) model.to(torch_device)
shutil.rmtree(directory)
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment