Commit d9fa1bad authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Fix failing torchscript test for xlnet

model.parameters() order is apparently not stable (only for xlnet, for some reason)
parent cd51893d
...@@ -236,11 +236,14 @@ class ModelTesterMixin: ...@@ -236,11 +236,14 @@ class ModelTesterMixin:
loaded_model.to(torch_device) loaded_model.to(torch_device)
loaded_model.eval() loaded_model.eval()
model_params = model.parameters() model_state_dict = model.state_dict()
loaded_model_params = loaded_model.parameters() loaded_model_state_dict = loaded_model.state_dict()
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
models_equal = True models_equal = True
for p1, p2 in zip(model_params, loaded_model_params): for layer_name, p1 in model_state_dict.items():
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0: if p1.data.ne(p2.data).sum() > 0:
models_equal = False models_equal = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment