"docs/vscode:/vscode.git/clone" did not exist on "e00df25aee329ef84450066019701c8d5075cfb3"
Commit e3bf9324 authored by patil-suraj's avatar patil-suraj
Browse files

don't hardcode device in tests

parent dc966cc4
...@@ -262,8 +262,6 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): ...@@ -262,8 +262,6 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
sizes = (32, 32) sizes = (32, 32)
low_res_size = (4, 4) low_res_size = (4, 4)
torch_device = "cpu"
noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device) noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device)
low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device) low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
time_step = torch.tensor([10] * noise.shape[0], device=torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
...@@ -355,8 +353,6 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -355,8 +353,6 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
transformer_dim = 32 transformer_dim = 32
seq_len = 16 seq_len = 16
torch_device = "cpu"
noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device) noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device) emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
time_step = torch.tensor([10] * noise.shape[0], device=torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment