Unverified commit 1dc231d1, authored by Sayak Paul, committed by GitHub
Browse files

[PixArt-Alpha] Support non-square images (#5672)



* debug

* support non-square images

* add: test

* fix: test

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 84cd9e8d
...@@ -339,6 +339,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): ...@@ -339,6 +339,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
elif self.is_input_vectorized: elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states) hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches: elif self.is_input_patches:
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states) hidden_states = self.pos_embed(hidden_states)
if self.adaln_single is not None: if self.adaln_single is not None:
...@@ -425,6 +426,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): ...@@ -425,6 +426,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
hidden_states = hidden_states.squeeze(1) hidden_states = hidden_states.squeeze(1)
# unpatchify # unpatchify
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5) height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape( hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
......
...@@ -174,13 +174,29 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -174,13 +174,29 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
print(torch.from_numpy(image_slice.flatten()))
self.assertEqual(image.shape, (1, 8, 8, 3)) self.assertEqual(image.shape, (1, 8, 8, 3))
expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675]) expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
max_diff = np.abs(image_slice.flatten() - expected_slice).max() max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3) self.assertLessEqual(max_diff, 1e-3)
def test_inference_non_square_images(self):
    """Check that the pipeline can produce a non-square (32x48) output.

    Builds the pipeline from the dummy components, runs it on CPU with
    ``height=32`` and ``width=48``, asserts the output shape is
    ``(1, 32, 48, 3)``, and compares the last three rows/columns of the
    final channel of the first image against stored reference values
    with an absolute tolerance of 1e-3.
    """
    target_device = "cpu"

    dummy_components = self.get_dummy_components()
    pipeline = self.pipeline_class(**dummy_components)
    pipeline.to(target_device)
    pipeline.set_progress_bar_config(disable=None)

    call_kwargs = self.get_dummy_inputs(target_device)
    # Explicit height/width override whatever size the dummy inputs imply.
    images = pipeline(**call_kwargs, height=32, width=48).images
    corner_patch = images[0, -3:, -3:, -1]

    self.assertEqual(images.shape, (1, 32, 48, 3))

    reference_values = np.array(
        [0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416]
    )
    largest_deviation = np.abs(corner_patch.flatten() - reference_values).max()
    self.assertLessEqual(largest_deviation, 1e-3)
def test_inference_with_embeddings_and_multiple_images(self): def test_inference_with_embeddings_and_multiple_images(self):
components = self.get_dummy_components() components = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment