"docs/vscode:/vscode.git/clone" did not exist on "e4e35296fb4a6af49c18dc814629a6acfdbe96a2"
Commit f2538c12 authored by thomwolf's avatar thomwolf
Browse files

all tests in torch no grad

parent a5df980c
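
The pattern applied throughout this commit is the same at every call site: a test-time forward pass that previously ran under autograd is wrapped in a torch.no_grad() block, so PyTorch skips building the computation graph and the tests use less memory and run faster. A minimal sketch of the before/after shape of the change, using a placeholder model and inputs rather than the actual test fixtures:

    import torch
    import torch.nn as nn

    # Placeholders standing in for the real model_class(config) and inputs_dict.
    model = nn.Linear(4, 2)
    inputs = torch.randn(3, 4)

    model.eval()  # inference behavior for dropout/batch norm; gradients are still tracked

    # Before: both forward passes record an autograd graph.
    first, second = model(inputs), model(inputs)

    # After: no graph is recorded during the forward passes.
    with torch.no_grad():
        first = model(inputs)
        second = model(inputs)

    # The determinism check then compares the two outputs, as in the first hunk below.
    assert torch.allclose(first, second)
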
@@ -120,7 +120,9 @@ class CommonTestCases:
 model = model_class(config)
 model.to(torch_device)
 model.eval()
-first, second = model(**inputs_dict)[0], model(**inputs_dict)[0]
+with torch.no_grad():
+    first = model(**inputs_dict)[0]
+    second = model(**inputs_dict)[0]
 out_1 = first.cpu().numpy()
 out_2 = second.cpu().numpy()
 out_1 = out_1[~np.isnan(out_1)]
@@ -142,6 +144,7 @@ class CommonTestCases:
 model = model_class(config)
 model.to(torch_device)
 model.eval()
-outputs = model(**inputs_dict)
+with torch.no_grad():
+    outputs = model(**inputs_dict)
 attentions = outputs[-1]
 self.assertEqual(model.config.output_attentions, True)
@@ -173,6 +176,7 @@ class CommonTestCases:
 model = model_class(config)
 model.to(torch_device)
 model.eval()
-outputs = model(**inputs_dict)
+with torch.no_grad():
+    outputs = model(**inputs_dict)
 self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
 self.assertEqual(model.config.output_attentions, True)
@@ -273,6 +277,7 @@ class CommonTestCases:
 inputs = inputs_dict.copy()
 inputs['head_mask'] = head_mask
-outputs = model(**inputs)
+with torch.no_grad():
+    outputs = model(**inputs)
 # Test that we can get a gradient back for importance score computation
@@ -320,6 +325,7 @@ class CommonTestCases:
 heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
                   -1: [0]}
 model.prune_heads(heads_to_prune)
-outputs = model(**inputs_dict)
+with torch.no_grad():
+    outputs = model(**inputs_dict)
 attentions = outputs[-1]
@@ -356,6 +362,7 @@ class CommonTestCases:
 model = model_class.from_pretrained(directory)
 model.to(torch_device)
-outputs = model(**inputs_dict)
+with torch.no_grad():
+    outputs = model(**inputs_dict)
 attentions = outputs[-1]
 self.assertEqual(attentions[0].shape[-3], 1)
@@ -385,6 +392,7 @@ class CommonTestCases:
 model.to(torch_device)
 model.eval()
-outputs = model(**inputs_dict)
+with torch.no_grad():
+    outputs = model(**inputs_dict)
 attentions = outputs[-1]
@@ -412,6 +420,7 @@ class CommonTestCases:
 model.to(torch_device)
 model.eval()
-outputs = model(**inputs_dict)
+with torch.no_grad():
+    outputs = model(**inputs_dict)
 attentions = outputs[-1]
@@ -429,6 +438,7 @@ class CommonTestCases:
 model.to(torch_device)
 shutil.rmtree(directory)
-outputs = model(**inputs_dict)
+with torch.no_grad():
+    outputs = model(**inputs_dict)
 attentions = outputs[-1]
@@ -440,6 +450,7 @@ class CommonTestCases:
 heads_to_prune = {0: [0], 2: [1, 2]}
 model.prune_heads(heads_to_prune)
-outputs = model(**inputs_dict)
+with torch.no_grad():
+    outputs = model(**inputs_dict)
 attentions = outputs[-1]
@@ -459,6 +470,7 @@ class CommonTestCases:
 model = model_class(config)
 model.to(torch_device)
 model.eval()
-outputs = model(**inputs_dict)
+with torch.no_grad():
+    outputs = model(**inputs_dict)
 hidden_states = outputs[-1]
 self.assertEqual(model.config.output_attentions, False)
@@ -594,6 +606,7 @@ class CommonTestCases:
 inputs_dict["encoder_inputs_embeds"] = wte(encoder_input_ids)
 inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)
-outputs = model(**inputs_dict)
+with torch.no_grad():
+    outputs = model(**inputs_dict)
 class GPTModelTester(CommonModelTester):
@@ -682,6 +695,7 @@ class CommonTestCases:
 model.to(torch_device)
 model.eval()
-outputs = model(input_ids, position_ids, token_type_ids)
-outputs = model(input_ids, position_ids)
-outputs = model(input_ids)
+with torch.no_grad():
+    outputs = model(input_ids, position_ids, token_type_ids)
+    outputs = model(input_ids, position_ids)
+    outputs = model(input_ids)
@@ -697,6 +711,7 @@ class CommonTestCases:
 model = self.lm_head_model_class(config)
 model.to(torch_device)
 model.eval()
-outputs = model(input_ids, position_ids, token_type_ids, lm_labels)
+with torch.no_grad():
+    outputs = model(input_ids, position_ids, token_type_ids, lm_labels)
 loss, lm_logits = outputs[:2]
@@ -714,6 +729,7 @@ class CommonTestCases:
 model = model_class(config)
 model.to(torch_device)
 model.eval()
-outputs = model(input_ids)
+with torch.no_grad():
+    outputs = model(input_ids)
 presents = outputs[-1]
 self.parent.assertEqual(self.num_hidden_layers, len(presents))
@@ -727,6 +743,7 @@ class CommonTestCases:
 model = self.double_head_model_class(config)
 model.to(torch_device)
 model.eval()
-outputs = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels,
-                token_type_ids=token_type_ids, position_ids=position_ids)
+with torch.no_grad():
+    outputs = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels,
+                    token_type_ids=token_type_ids, position_ids=position_ids)
 lm_loss, mc_loss, lm_logits, mc_logits = outputs[:4]
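
Every call site above already ran under model.eval(); the commit adds torch.no_grad() on top because the two do different things: eval() only switches modules such as dropout and batch norm into inference behavior, while no_grad() turns off gradient bookkeeping entirely. A small illustration of the difference (the toy model here is only an example, not one of the tested classes):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
    x = torch.randn(1, 4)
    model.eval()

    y = model(x)
    print(y.requires_grad)  # True: eval() alone still builds the autograd graph

    with torch.no_grad():
        y = model(x)
    print(y.requires_grad)  # False: no graph, less memory held by the tests
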