Unverified Commit d6e92044 authored by Partho's avatar Partho Committed by GitHub
Browse files

wrap forward passes with torch.no_grad() (#19274)

parent 2403dbd6
......@@ -444,7 +444,8 @@ class ConvBertModelIntegrationTest(unittest.TestCase):
def test_inference_no_head(self):
model = ConvBertModel.from_pretrained("YituTech/conv-bert-base")
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])
output = model(input_ids)[0]
with torch.no_grad():
output = model(input_ids)[0]
expected_shape = torch.Size((1, 6, 768))
self.assertEqual(output.shape, expected_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment