Unverified Commit 013955b5 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Dit] Fix dit tests (#2034)

* [Dit] Fix dit tests

* up
parent ed616bd8
......@@ -36,10 +36,10 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
transformer = Transformer2DModel(
sample_size=4,
sample_size=16,
num_layers=2,
patch_size=2,
attention_head_dim=2,
patch_size=4,
attention_head_dim=8,
num_attention_heads=2,
in_channels=4,
out_channels=8,
......@@ -79,10 +79,8 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
self.assertEqual(image.shape, (1, 4, 4, 3))
expected_slice = np.array(
[0.44405967, 0.33592293, 0.6093237, 0.48981372, 0.79098296, 0.7504172, 0.59413105, 0.49462673, 0.35190058]
)
self.assertEqual(image.shape, (1, 16, 16, 3))
expected_slice = np.array([0.4380, 0.4141, 0.5159, 0.0000, 0.4282, 0.6680, 0.5485, 0.2545, 0.6719])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment