"tools/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "f1105409c29eccffe5a556c62362bbe216244065"
Commit f2538c12 authored by thomwolf
Browse files

all tests in torch no grad

parent a5df980c
...@@ -120,7 +120,9 @@ class CommonTestCases: ...@@ -120,7 +120,9 @@ class CommonTestCases:
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
first, second = model(**inputs_dict)[0], model(**inputs_dict)[0] with torch.no_grad():
first = model(**inputs_dict)[0]
second = model(**inputs_dict)[0]
out_1 = first.cpu().numpy() out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy() out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)] out_1 = out_1[~np.isnan(out_1)]
...@@ -142,7 +144,8 @@ class CommonTestCases: ...@@ -142,7 +144,8 @@ class CommonTestCases:
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(**inputs_dict) with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(model.config.output_attentions, True) self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False) self.assertEqual(model.config.output_hidden_states, False)
...@@ -173,7 +176,8 @@ class CommonTestCases: ...@@ -173,7 +176,8 @@ class CommonTestCases:
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(**inputs_dict) with torch.no_grad():
outputs = model(**inputs_dict)
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs)) self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
self.assertEqual(model.config.output_attentions, True) self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(model.config.output_hidden_states, True)
...@@ -273,7 +277,8 @@ class CommonTestCases: ...@@ -273,7 +277,8 @@ class CommonTestCases:
inputs = inputs_dict.copy() inputs = inputs_dict.copy()
inputs['head_mask'] = head_mask inputs['head_mask'] = head_mask
outputs = model(**inputs) with torch.no_grad():
outputs = model(**inputs)
# Test that we can get a gradient back for importance score computation # Test that we can get a gradient back for importance score computation
output = sum(t.sum() for t in outputs[0]) output = sum(t.sum() for t in outputs[0])
...@@ -320,7 +325,8 @@ class CommonTestCases: ...@@ -320,7 +325,8 @@ class CommonTestCases:
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0]} -1: [0]}
model.prune_heads(heads_to_prune) model.prune_heads(heads_to_prune)
outputs = model(**inputs_dict) with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
...@@ -356,7 +362,8 @@ class CommonTestCases: ...@@ -356,7 +362,8 @@ class CommonTestCases:
model = model_class.from_pretrained(directory) model = model_class.from_pretrained(directory)
model.to(torch_device) model.to(torch_device)
outputs = model(**inputs_dict) with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], 1) self.assertEqual(attentions[0].shape[-3], 1)
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads) self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
...@@ -385,7 +392,8 @@ class CommonTestCases: ...@@ -385,7 +392,8 @@ class CommonTestCases:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(**inputs_dict) with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], 1) self.assertEqual(attentions[0].shape[-3], 1)
...@@ -412,7 +420,8 @@ class CommonTestCases: ...@@ -412,7 +420,8 @@ class CommonTestCases:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(**inputs_dict) with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1) self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
...@@ -429,7 +438,8 @@ class CommonTestCases: ...@@ -429,7 +438,8 @@ class CommonTestCases:
model.to(torch_device) model.to(torch_device)
shutil.rmtree(directory) shutil.rmtree(directory)
outputs = model(**inputs_dict) with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1) self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
...@@ -440,7 +450,8 @@ class CommonTestCases: ...@@ -440,7 +450,8 @@ class CommonTestCases:
heads_to_prune = {0: [0], 2: [1, 2]} heads_to_prune = {0: [0], 2: [1, 2]}
model.prune_heads(heads_to_prune) model.prune_heads(heads_to_prune)
outputs = model(**inputs_dict) with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads -1) self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads -1)
...@@ -459,7 +470,8 @@ class CommonTestCases: ...@@ -459,7 +470,8 @@ class CommonTestCases:
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(**inputs_dict) with torch.no_grad():
outputs = model(**inputs_dict)
hidden_states = outputs[-1] hidden_states = outputs[-1]
self.assertEqual(model.config.output_attentions, False) self.assertEqual(model.config.output_attentions, False)
self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(model.config.output_hidden_states, True)
...@@ -594,7 +606,8 @@ class CommonTestCases: ...@@ -594,7 +606,8 @@ class CommonTestCases:
inputs_dict["encoder_inputs_embeds"] = wte(encoder_input_ids) inputs_dict["encoder_inputs_embeds"] = wte(encoder_input_ids)
inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids) inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)
outputs = model(**inputs_dict) with torch.no_grad():
outputs = model(**inputs_dict)
class GPTModelTester(CommonModelTester): class GPTModelTester(CommonModelTester):
...@@ -682,9 +695,10 @@ class CommonTestCases: ...@@ -682,9 +695,10 @@ class CommonTestCases:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids, position_ids, token_type_ids) with torch.no_grad():
outputs = model(input_ids, position_ids) outputs = model(input_ids, position_ids, token_type_ids)
outputs = model(input_ids) outputs = model(input_ids, position_ids)
outputs = model(input_ids)
hidden_state = outputs[0] hidden_state = outputs[0]
self.parent.assertListEqual( self.parent.assertListEqual(
...@@ -697,7 +711,8 @@ class CommonTestCases: ...@@ -697,7 +711,8 @@ class CommonTestCases:
model = self.lm_head_model_class(config) model = self.lm_head_model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids, position_ids, token_type_ids, lm_labels) with torch.no_grad():
outputs = model(input_ids, position_ids, token_type_ids, lm_labels)
loss, lm_logits = outputs[:2] loss, lm_logits = outputs[:2]
total_voc = self.vocab_size total_voc = self.vocab_size
...@@ -714,7 +729,8 @@ class CommonTestCases: ...@@ -714,7 +729,8 @@ class CommonTestCases:
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids) with torch.no_grad():
outputs = model(input_ids)
presents = outputs[-1] presents = outputs[-1]
self.parent.assertEqual(self.num_hidden_layers, len(presents)) self.parent.assertEqual(self.num_hidden_layers, len(presents))
self.parent.assertListEqual( self.parent.assertListEqual(
...@@ -727,7 +743,8 @@ class CommonTestCases: ...@@ -727,7 +743,8 @@ class CommonTestCases:
model = self.double_head_model_class(config) model = self.double_head_model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels, with torch.no_grad():
outputs = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels,
token_type_ids=token_type_ids, position_ids=position_ids) token_type_ids=token_type_ids, position_ids=position_ids)
lm_loss, mc_loss, lm_logits, mc_logits = outputs[:4] lm_loss, mc_loss, lm_logits, mc_logits = outputs[:4]
loss = [lm_loss, mc_loss] loss = [lm_loss, mc_loss]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment