Commit d9fa1bad authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Fix failing torchscript test for xlnet

model.parameters() order is apparently not stable (only for xlnet, for some reason)
parent cd51893d
......@@ -236,11 +236,14 @@ class ModelTesterMixin:
loaded_model.to(torch_device)
loaded_model.eval()
model_params = model.parameters()
loaded_model_params = loaded_model.parameters()
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
models_equal = True
for p1, p2 in zip(model_params, loaded_model_params):
for layer_name, p1 in model_state_dict.items():
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment