Unverified Commit 7d6f30e8 authored by Junsong Chen's avatar Junsong Chen Committed by GitHub
Browse files

[Fix: pixart-alpha] random 512px resolution bug (#5842)



* [Fix: pixart-alpha]
Add ASPECT_RATIO_512_BIN, used by use_resolution_binning, for random 512px image generation.

* add slow test file for 512px generation without resolution binning

* fix: slow tests for resolution binning.

---------
Co-authored-by: jschen <chenjunsong4@h-partners.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 6d2e19f7
...@@ -97,6 +97,42 @@ ASPECT_RATIO_1024_BIN = { ...@@ -97,6 +97,42 @@ ASPECT_RATIO_1024_BIN = {
"4.0": [2048.0, 512.0], "4.0": [2048.0, 512.0],
} }
# Aspect-ratio bins for the 512px PixArt-alpha checkpoints, mirroring
# ASPECT_RATIO_1024_BIN defined above for the 1024px models.
# Keys are the rounded ratio of the first dimension to the second
# (e.g. "0.25" -> 256.0 / 1024.0); values are resolution pairs whose entries
# are all multiples of 32. NOTE(review): value order is presumed
# [height, width], matching the classify_height_width_bin(height, width, ...)
# call site — confirm against that helper. This table is selected in
# __call__ when use_resolution_binning is enabled and
# transformer.config.sample_size != 128 (i.e. the 512px checkpoints).
ASPECT_RATIO_512_BIN = {
    "0.25": [256.0, 1024.0],
    "0.28": [256.0, 928.0],
    "0.32": [288.0, 896.0],
    "0.33": [288.0, 864.0],
    "0.35": [288.0, 832.0],
    "0.4": [320.0, 800.0],
    "0.42": [320.0, 768.0],
    "0.48": [352.0, 736.0],
    "0.5": [352.0, 704.0],
    "0.52": [352.0, 672.0],
    "0.57": [384.0, 672.0],
    "0.6": [384.0, 640.0],
    "0.68": [416.0, 608.0],
    "0.72": [416.0, 576.0],
    "0.78": [448.0, 576.0],
    "0.82": [448.0, 544.0],
    "0.88": [480.0, 544.0],
    "0.94": [480.0, 512.0],
    "1.0": [512.0, 512.0],
    "1.07": [512.0, 480.0],
    "1.13": [544.0, 480.0],
    "1.21": [544.0, 448.0],
    "1.29": [576.0, 448.0],
    "1.38": [576.0, 416.0],
    "1.46": [608.0, 416.0],
    "1.67": [640.0, 384.0],
    "1.75": [672.0, 384.0],
    "2.0": [704.0, 352.0],
    "2.09": [736.0, 352.0],
    "2.4": [768.0, 320.0],
    "2.5": [800.0, 320.0],
    "3.0": [864.0, 288.0],
    "4.0": [1024.0, 256.0],
}
class PixArtAlphaPipeline(DiffusionPipeline): class PixArtAlphaPipeline(DiffusionPipeline):
r""" r"""
...@@ -691,8 +727,11 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -691,8 +727,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
height = height or self.transformer.config.sample_size * self.vae_scale_factor height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor width = width or self.transformer.config.sample_size * self.vae_scale_factor
if use_resolution_binning: if use_resolution_binning:
aspect_ratio_bin = (
ASPECT_RATIO_1024_BIN if self.transformer.config.sample_size == 128 else ASPECT_RATIO_512_BIN
)
orig_height, orig_width = height, width orig_height, orig_width = height, width
height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN) height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
self.check_inputs( self.check_inputs(
prompt, prompt,
......
...@@ -320,6 +320,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -320,6 +320,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@slow @slow
@require_torch_gpu @require_torch_gpu
class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
ckpt_id_1024 = "PixArt-alpha/PixArt-XL-2-1024-MS"
ckpt_id_512 = "PixArt-alpha/PixArt-XL-2-512x512"
prompt = "A small cactus with a happy face in the Sahara desert."
def tearDown(self): def tearDown(self):
super().tearDown() super().tearDown()
gc.collect() gc.collect()
...@@ -328,10 +332,10 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): ...@@ -328,10 +332,10 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
def test_pixart_1024_fast(self): def test_pixart_1024_fast(self):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert." prompt = self.prompt
image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
...@@ -345,10 +349,10 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): ...@@ -345,10 +349,10 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
def test_pixart_512_fast(self): def test_pixart_512_fast(self):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-512x512", torch_dtype=torch.float16) pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert." prompt = self.prompt
image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
...@@ -362,9 +366,9 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): ...@@ -362,9 +366,9 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
def test_pixart_1024(self): def test_pixart_1024(self):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert." prompt = self.prompt
image = pipe(prompt, generator=generator, output_type="np").images image = pipe(prompt, generator=generator, output_type="np").images
...@@ -378,10 +382,10 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): ...@@ -378,10 +382,10 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
def test_pixart_512(self): def test_pixart_512(self):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-512x512", torch_dtype=torch.float16) pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert." prompt = self.prompt
image = pipe(prompt, generator=generator, output_type="np").images image = pipe(prompt, generator=generator, output_type="np").images
...@@ -395,17 +399,66 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): ...@@ -395,17 +399,66 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
def test_pixart_1024_without_resolution_binning(self): def test_pixart_1024_without_resolution_binning(self):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert." prompt = self.prompt
height, width = 1024, 768
num_inference_steps = 10
image = pipe(
prompt,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
output_type="np",
).images
image_slice = image[0, -3:, -3:, -1]
generator = torch.manual_seed(0)
no_res_bin_image = pipe(
prompt,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
output_type="np",
use_resolution_binning=False,
).images
no_res_bin_image_slice = no_res_bin_image[0, -3:, -3:, -1]
image = pipe(prompt, generator=generator, num_inference_steps=5, output_type="np").images assert not np.allclose(image_slice, no_res_bin_image_slice, atol=1e-4, rtol=1e-4)
def test_pixart_512_without_resolution_binning(self):
generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
prompt = self.prompt
height, width = 512, 768
num_inference_steps = 10
image = pipe(
prompt,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
output_type="np",
).images
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
no_res_bin_image = pipe( no_res_bin_image = pipe(
prompt, generator=generator, num_inference_steps=5, output_type="np", use_resolution_binning=False prompt,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
output_type="np",
use_resolution_binning=False,
).images ).images
no_res_bin_image_slice = no_res_bin_image[0, -3:, -3:, -1] no_res_bin_image_slice = no_res_bin_image[0, -3:, -3:, -1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment