Unverified Commit 688c3e8e authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Update `max_diff` in `test_save_load_fast_init_to_base` (#19849)



* Fix test_save_load_fast_init_to_base

* Fix test_save_load_fast_init_to_base

* update
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 7829c890
...@@ -398,7 +398,9 @@ class ModelTesterMixin: ...@@ -398,7 +398,9 @@ class ModelTesterMixin:
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False) model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)
for key in model_fast_init.state_dict().keys(): for key in model_fast_init.state_dict().keys():
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() max_diff = torch.max(
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
).item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_initialization(self): def test_initialization(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment