Commit 67a8be8e authored by thomwolf's avatar thomwolf
Browse files

fix backward in tests

parent f2538c12
......@@ -277,8 +277,7 @@ class CommonTestCases:
inputs = inputs_dict.copy()
inputs['head_mask'] = head_mask
with torch.no_grad():
outputs = model(**inputs)
outputs = model(**inputs)
# Test that we can get a gradient back for importance score computation
output = sum(t.sum() for t in outputs[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment