Unverified Commit ca257a06 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Fix torchscript tests (#13701)

parent 5b570754
...@@ -436,7 +436,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -436,7 +436,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt")) torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt"))
loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device) loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device)) loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
......
...@@ -273,7 +273,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -273,7 +273,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt")) torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt"))
loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device) loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device)) loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
......
...@@ -325,7 +325,12 @@ class FlaubertModelTester(object): ...@@ -325,7 +325,12 @@ class FlaubertModelTester(object):
choice_labels, choice_labels,
input_mask, input_mask,
) = config_and_inputs ) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths} inputs_dict = {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"lengths": input_lengths,
"attention_mask": input_mask,
}
return config, inputs_dict return config, inputs_dict
...@@ -422,7 +427,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -422,7 +427,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt")) torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt"))
loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device) loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device)) loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment