Unverified Commit d533c7e9 authored by Funtowicz Morgan's avatar Funtowicz Morgan Committed by GitHub
Browse files

[fix] T5 ONNX test: model.to(torch_device) (#5769)


Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>
parent d0486c8b
...@@ -336,7 +336,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -336,7 +336,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs[0].return_tuple = True config_and_inputs[0].return_tuple = True
model = T5Model(config_and_inputs[0]) model = T5Model(config_and_inputs[0]).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
torch.onnx.export( torch.onnx.export(
model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9, model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment