"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2ac5b9325ed3b54950c6c61fd5838ac6e55a9fe1"
Unverified Commit 870a9542 authored by Partho's avatar Partho Committed by GitHub
Browse files

wrap forward passes with torch.no_grad() (#19438)

parent 692c5be7
...@@ -457,7 +457,8 @@ class RoFormerModelIntegrationTest(unittest.TestCase): ...@@ -457,7 +457,8 @@ class RoFormerModelIntegrationTest(unittest.TestCase):
def test_inference_masked_lm(self): def test_inference_masked_lm(self):
model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base") model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0] with torch.no_grad():
output = model(input_ids)[0]
# TODO Replace vocab size # TODO Replace vocab size
vocab_size = 50000 vocab_size = 50000
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment