Commit 12726f85 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Remove redundant torch.jit.trace in tests.

This looks like it could be expensive, so don't run it twice.
parent ac1b449c
...@@ -218,12 +218,11 @@ class CommonTestCases: ...@@ -218,12 +218,11 @@ class CommonTestCases:
inputs = inputs_dict['input_ids'] # Let's keep only input_ids inputs = inputs_dict['input_ids'] # Let's keep only input_ids
try: try:
torch.jit.trace(model, inputs) traced_gpt2 = torch.jit.trace(model, inputs)
except RuntimeError: except RuntimeError:
self.fail("Couldn't trace module.") self.fail("Couldn't trace module.")
try: try:
traced_gpt2 = torch.jit.trace(model, inputs)
torch.jit.save(traced_gpt2, "traced_model.pt") torch.jit.save(traced_gpt2, "traced_model.pt")
except RuntimeError: except RuntimeError:
self.fail("Couldn't save module.") self.fail("Couldn't save module.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment