Unverified Commit a785992c authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Tests] fix more sharding tests (#8797)

* fix

* fix

* ugly

* okay

* fix more

* fix oops
parent 35cc66dc
...@@ -885,11 +885,11 @@ class ModelTesterMixin: ...@@ -885,11 +885,11 @@ class ModelTesterMixin:
@require_torch_gpu @require_torch_gpu
def test_sharded_checkpoints(self): def test_sharded_checkpoints(self):
torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common() config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval() model = self.model_class(**config).eval()
model = model.to(torch_device) model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict) base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""] model_size = compute_module_sizes(model)[""]
...@@ -909,7 +909,8 @@ class ModelTesterMixin: ...@@ -909,7 +909,8 @@ class ModelTesterMixin:
new_model = new_model.to(torch_device) new_model = new_model.to(torch_device)
torch.manual_seed(0) torch.manual_seed(0)
_, inputs_dict = self.prepare_init_args_and_inputs_for_common() if "generator" in inputs_dict:
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict) new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
...@@ -942,7 +943,8 @@ class ModelTesterMixin: ...@@ -942,7 +943,8 @@ class ModelTesterMixin:
new_model = new_model.to(torch_device) new_model = new_model.to(torch_device)
torch.manual_seed(0) torch.manual_seed(0)
_, inputs_dict = self.prepare_init_args_and_inputs_for_common() if "generator" in inputs_dict:
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict) new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment