"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "653379c094f462eea33760e3f3049ec2971ef8a3"
Unverified Commit 8ef62ec9 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Fix torchscript tests (#12336)

* Fix torchscript tests

* Better test

* Remove bogus print
parent aef3823e
...@@ -564,13 +564,34 @@ class ModelTesterMixin: ...@@ -564,13 +564,34 @@ class ModelTesterMixin:
model_state_dict = model.state_dict() model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict() loaded_model_state_dict = loaded_model.state_dict()
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
self.assertTrue(found_buffer)
model_buffers.pop(i)
models_equal = True models_equal = True
for layer_name, p1 in model_state_dict.items(): for layer_name, p1 in model_state_dict.items():
p2 = loaded_model_state_dict[layer_name] if layer_name in loaded_model_state_dict:
if p1.data.ne(p2.data).sum() > 0: p2 = loaded_model_state_dict[layer_name]
models_equal = False if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal) self.assertTrue(models_equal)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment