Commit f2538c12 authored by thomwolf's avatar thomwolf
Browse files

all tests in torch no grad

parent a5df980c
...@@ -120,7 +120,9 @@ class CommonTestCases: ...@@ -120,7 +120,9 @@ class CommonTestCases:
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
first, second = model(**inputs_dict)[0], model(**inputs_dict)[0] with torch.no_grad():
first = model(**inputs_dict)[0]
second = model(**inputs_dict)[0]
out_1 = first.cpu().numpy() out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy() out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)] out_1 = out_1[~np.isnan(out_1)]
...@@ -142,6 +144,7 @@ class CommonTestCases: ...@@ -142,6 +144,7 @@ class CommonTestCases:
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(model.config.output_attentions, True) self.assertEqual(model.config.output_attentions, True)
...@@ -173,6 +176,7 @@ class CommonTestCases: ...@@ -173,6 +176,7 @@ class CommonTestCases:
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs)) self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
self.assertEqual(model.config.output_attentions, True) self.assertEqual(model.config.output_attentions, True)
...@@ -273,6 +277,7 @@ class CommonTestCases: ...@@ -273,6 +277,7 @@ class CommonTestCases:
inputs = inputs_dict.copy() inputs = inputs_dict.copy()
inputs['head_mask'] = head_mask inputs['head_mask'] = head_mask
with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs)
# Test that we can get a gradient back for importance score computation # Test that we can get a gradient back for importance score computation
...@@ -320,6 +325,7 @@ class CommonTestCases: ...@@ -320,6 +325,7 @@ class CommonTestCases:
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0]} -1: [0]}
model.prune_heads(heads_to_prune) model.prune_heads(heads_to_prune)
with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
...@@ -356,6 +362,7 @@ class CommonTestCases: ...@@ -356,6 +362,7 @@ class CommonTestCases:
model = model_class.from_pretrained(directory) model = model_class.from_pretrained(directory)
model.to(torch_device) model.to(torch_device)
with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], 1) self.assertEqual(attentions[0].shape[-3], 1)
...@@ -385,6 +392,7 @@ class CommonTestCases: ...@@ -385,6 +392,7 @@ class CommonTestCases:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
...@@ -412,6 +420,7 @@ class CommonTestCases: ...@@ -412,6 +420,7 @@ class CommonTestCases:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
...@@ -429,6 +438,7 @@ class CommonTestCases: ...@@ -429,6 +438,7 @@ class CommonTestCases:
model.to(torch_device) model.to(torch_device)
shutil.rmtree(directory) shutil.rmtree(directory)
with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
...@@ -440,6 +450,7 @@ class CommonTestCases: ...@@ -440,6 +450,7 @@ class CommonTestCases:
heads_to_prune = {0: [0], 2: [1, 2]} heads_to_prune = {0: [0], 2: [1, 2]}
model.prune_heads(heads_to_prune) model.prune_heads(heads_to_prune)
with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
...@@ -459,6 +470,7 @@ class CommonTestCases: ...@@ -459,6 +470,7 @@ class CommonTestCases:
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
hidden_states = outputs[-1] hidden_states = outputs[-1]
self.assertEqual(model.config.output_attentions, False) self.assertEqual(model.config.output_attentions, False)
...@@ -594,6 +606,7 @@ class CommonTestCases: ...@@ -594,6 +606,7 @@ class CommonTestCases:
inputs_dict["encoder_inputs_embeds"] = wte(encoder_input_ids) inputs_dict["encoder_inputs_embeds"] = wte(encoder_input_ids)
inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids) inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)
with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
class GPTModelTester(CommonModelTester): class GPTModelTester(CommonModelTester):
...@@ -682,6 +695,7 @@ class CommonTestCases: ...@@ -682,6 +695,7 @@ class CommonTestCases:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad():
outputs = model(input_ids, position_ids, token_type_ids) outputs = model(input_ids, position_ids, token_type_ids)
outputs = model(input_ids, position_ids) outputs = model(input_ids, position_ids)
outputs = model(input_ids) outputs = model(input_ids)
...@@ -697,6 +711,7 @@ class CommonTestCases: ...@@ -697,6 +711,7 @@ class CommonTestCases:
model = self.lm_head_model_class(config) model = self.lm_head_model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad():
outputs = model(input_ids, position_ids, token_type_ids, lm_labels) outputs = model(input_ids, position_ids, token_type_ids, lm_labels)
loss, lm_logits = outputs[:2] loss, lm_logits = outputs[:2]
...@@ -714,6 +729,7 @@ class CommonTestCases: ...@@ -714,6 +729,7 @@ class CommonTestCases:
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad():
outputs = model(input_ids) outputs = model(input_ids)
presents = outputs[-1] presents = outputs[-1]
self.parent.assertEqual(self.num_hidden_layers, len(presents)) self.parent.assertEqual(self.num_hidden_layers, len(presents))
...@@ -727,6 +743,7 @@ class CommonTestCases: ...@@ -727,6 +743,7 @@ class CommonTestCases:
model = self.double_head_model_class(config) model = self.double_head_model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad():
outputs = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels, outputs = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels,
token_type_ids=token_type_ids, position_ids=position_ids) token_type_ids=token_type_ids, position_ids=position_ids)
lm_loss, mc_loss, lm_logits, mc_logits = outputs[:4] lm_loss, mc_loss, lm_logits, mc_logits = outputs[:4]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment