Unverified Commit d4306dae authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `AlignModelTest` tests (#21923)



* fix

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent c5a1ff9e
...@@ -65,7 +65,7 @@ class AlignVisionModelTester: ...@@ -65,7 +65,7 @@ class AlignVisionModelTester:
def __init__( def __init__(
self, self,
parent, parent,
batch_size=13, batch_size=12,
image_size=32, image_size=32,
num_channels=3, num_channels=3,
kernel_sizes=[3, 3, 5], kernel_sizes=[3, 3, 5],
...@@ -234,7 +234,7 @@ class AlignTextModelTester: ...@@ -234,7 +234,7 @@ class AlignTextModelTester:
def __init__( def __init__(
self, self,
parent, parent,
batch_size=13, batch_size=12,
seq_length=7, seq_length=7,
is_training=True, is_training=True,
use_input_mask=True, use_input_mask=True,
...@@ -521,6 +521,15 @@ class AlignModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -521,6 +521,15 @@ class AlignModelTest(ModelTesterMixin, unittest.TestCase):
model_state_dict = model.state_dict() model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict() loaded_model_state_dict = loaded_model.state_dict()
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
models_equal = True models_equal = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment