"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6b1ff250842f52136d5159bb67a26b50ba01485d"
Commit f2538c12 authored by thomwolf's avatar thomwolf
Browse files

all tests in torch no grad

parent a5df980c
......@@ -120,7 +120,9 @@ class CommonTestCases:
model = model_class(config)
model.to(torch_device)
model.eval()
first, second = model(**inputs_dict)[0], model(**inputs_dict)[0]
with torch.no_grad():
first = model(**inputs_dict)[0]
second = model(**inputs_dict)[0]
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
......@@ -142,7 +144,8 @@ class CommonTestCases:
model = model_class(config)
model.to(torch_device)
model.eval()
outputs = model(**inputs_dict)
with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1]
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False)
......@@ -173,7 +176,8 @@ class CommonTestCases:
model = model_class(config)
model.to(torch_device)
model.eval()
outputs = model(**inputs_dict)
with torch.no_grad():
outputs = model(**inputs_dict)
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True)
......@@ -273,7 +277,8 @@ class CommonTestCases:
inputs = inputs_dict.copy()
inputs['head_mask'] = head_mask
outputs = model(**inputs)
with torch.no_grad():
outputs = model(**inputs)
# Test that we can get a gradient back for importance score computation
output = sum(t.sum() for t in outputs[0])
......@@ -320,7 +325,8 @@ class CommonTestCases:
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0]}
model.prune_heads(heads_to_prune)
outputs = model(**inputs_dict)
with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1]
......@@ -356,7 +362,8 @@ class CommonTestCases:
model = model_class.from_pretrained(directory)
model.to(torch_device)
outputs = model(**inputs_dict)
with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], 1)
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
......@@ -385,7 +392,8 @@ class CommonTestCases:
model.to(torch_device)
model.eval()
outputs = model(**inputs_dict)
with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], 1)
......@@ -412,7 +420,8 @@ class CommonTestCases:
model.to(torch_device)
model.eval()
outputs = model(**inputs_dict)
with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
......@@ -429,7 +438,8 @@ class CommonTestCases:
model.to(torch_device)
shutil.rmtree(directory)
outputs = model(**inputs_dict)
with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
......@@ -440,7 +450,8 @@ class CommonTestCases:
heads_to_prune = {0: [0], 2: [1, 2]}
model.prune_heads(heads_to_prune)
outputs = model(**inputs_dict)
with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads -1)
......@@ -459,7 +470,8 @@ class CommonTestCases:
model = model_class(config)
model.to(torch_device)
model.eval()
outputs = model(**inputs_dict)
with torch.no_grad():
outputs = model(**inputs_dict)
hidden_states = outputs[-1]
self.assertEqual(model.config.output_attentions, False)
self.assertEqual(model.config.output_hidden_states, True)
......@@ -594,7 +606,8 @@ class CommonTestCases:
inputs_dict["encoder_inputs_embeds"] = wte(encoder_input_ids)
inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)
outputs = model(**inputs_dict)
with torch.no_grad():
outputs = model(**inputs_dict)
class GPTModelTester(CommonModelTester):
......@@ -682,9 +695,10 @@ class CommonTestCases:
model.to(torch_device)
model.eval()
outputs = model(input_ids, position_ids, token_type_ids)
outputs = model(input_ids, position_ids)
outputs = model(input_ids)
with torch.no_grad():
outputs = model(input_ids, position_ids, token_type_ids)
outputs = model(input_ids, position_ids)
outputs = model(input_ids)
hidden_state = outputs[0]
self.parent.assertListEqual(
......@@ -697,7 +711,8 @@ class CommonTestCases:
model = self.lm_head_model_class(config)
model.to(torch_device)
model.eval()
outputs = model(input_ids, position_ids, token_type_ids, lm_labels)
with torch.no_grad():
outputs = model(input_ids, position_ids, token_type_ids, lm_labels)
loss, lm_logits = outputs[:2]
total_voc = self.vocab_size
......@@ -714,7 +729,8 @@ class CommonTestCases:
model = model_class(config)
model.to(torch_device)
model.eval()
outputs = model(input_ids)
with torch.no_grad():
outputs = model(input_ids)
presents = outputs[-1]
self.parent.assertEqual(self.num_hidden_layers, len(presents))
self.parent.assertListEqual(
......@@ -727,7 +743,8 @@ class CommonTestCases:
model = self.double_head_model_class(config)
model.to(torch_device)
model.eval()
outputs = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels,
with torch.no_grad():
outputs = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels,
token_type_ids=token_type_ids, position_ids=position_ids)
lm_loss, mc_loss, lm_logits, mc_logits = outputs[:4]
loss = [lm_loss, mc_loss]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment