"vscode:/vscode.git/clone" did not exist on "d76af4d4c12bc0fe121b4530102f8f15964e98cc"
Unverified Commit 31adeb41 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Tests] fix sharding tests (#8764)

fix sharding tests
parent a7b9634e
......@@ -415,9 +415,10 @@ class GaussianFourierProjection(nn.Module):
if set_W_to_weight:
# to delete later
del self.weight
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.weight = self.W
del self.W
def forward(self, x):
if self.log:
......
......@@ -361,9 +361,10 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
forward_requires_fresh_args = True
def inputs_dict(self, seed=None):
generator = torch.Generator("cpu")
if seed is not None:
generator.manual_seed(0)
if seed is None:
generator = torch.Generator("cpu").manual_seed(0)
else:
generator = torch.Generator("cpu").manual_seed(seed)
image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
return {"sample": image, "generator": generator}
......
......@@ -905,11 +905,13 @@ class ModelTesterMixin:
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
self.assertTrue(actual_num_shards == expected_num_shards)
new_model = self.model_class.from_pretrained(tmp_dir)
new_model = self.model_class.from_pretrained(tmp_dir).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@require_torch_gpu
......@@ -940,6 +942,7 @@ class ModelTesterMixin:
new_model = new_model.to(torch_device)
torch.manual_seed(0)
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment