Unverified Commit 0c72006e authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix slow tsets (#3066)

* fix slow tsets

* make style
parent a89a14fa
......@@ -411,7 +411,9 @@ class Transformer2DModelTests(unittest.TestCase):
assert attention_scores.shape == (1, 64, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
expected_slice = torch.tensor([0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598])
expected_slice = torch.tensor(
[0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_spatial_transformer_timestep(self):
......@@ -442,9 +444,11 @@ class Transformer2DModelTests(unittest.TestCase):
output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
output_slice_2 = attention_scores_2[0, -1, -3:, -3:]
expected_slice = torch.tensor([-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703])
expected_slice = torch.tensor(
[-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703], device=torch_device
)
expected_slice_2 = torch.tensor(
[-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348]
[-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348], device=torch_device
)
assert torch.allclose(output_slice_1.flatten(), expected_slice, atol=1e-3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment