Unverified Commit 52eb0348 authored by fboulnois's avatar fboulnois Committed by GitHub
Browse files

Standardize on using `image` argument in all pipelines (#1361)

* feat: switch core pipelines to use image arg

* test: update tests for core pipelines

* feat: switch examples to use image arg

* docs: update docs to use image arg

* style: format code using black and doc-builder

* fix: deprecate use of init_image in all pipelines
parent 2bbf8b67
...@@ -79,7 +79,7 @@ class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase ...@@ -79,7 +79,7 @@ class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase
init_image = self.dummy_image.to(device) init_image = self.dummy_image.to(device)
generator = torch.Generator(device=device).manual_seed(0) generator = torch.Generator(device=device).manual_seed(0)
image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images image = ldm(image=init_image, generator=generator, num_inference_steps=2, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
...@@ -124,7 +124,7 @@ class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase): ...@@ -124,7 +124,7 @@ class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase):
ldm.set_progress_bar_config(disable=None) ldm.set_progress_bar_config(disable=None)
generator = torch.Generator(device=torch_device).manual_seed(0) generator = torch.Generator(device=torch_device).manual_seed(0)
image = ldm(init_image, generator=generator, num_inference_steps=20, output_type="numpy").images image = ldm(image=init_image, generator=generator, num_inference_steps=20, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
......
...@@ -186,7 +186,7 @@ class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -186,7 +186,7 @@ class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
source_prompt=source_prompt, source_prompt=source_prompt,
generator=generator, generator=generator,
num_inference_steps=2, num_inference_steps=2,
init_image=init_image, image=init_image,
eta=0.1, eta=0.1,
strength=0.8, strength=0.8,
guidance_scale=3, guidance_scale=3,
...@@ -244,7 +244,7 @@ class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -244,7 +244,7 @@ class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
source_prompt=source_prompt, source_prompt=source_prompt,
generator=generator, generator=generator,
num_inference_steps=2, num_inference_steps=2,
init_image=init_image, image=init_image,
eta=0.1, eta=0.1,
strength=0.8, strength=0.8,
guidance_scale=3, guidance_scale=3,
...@@ -297,7 +297,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase): ...@@ -297,7 +297,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
output = pipe( output = pipe(
prompt=prompt, prompt=prompt,
source_prompt=source_prompt, source_prompt=source_prompt,
init_image=init_image, image=init_image,
num_inference_steps=100, num_inference_steps=100,
eta=0.1, eta=0.1,
strength=0.85, strength=0.85,
...@@ -336,7 +336,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase): ...@@ -336,7 +336,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
output = pipe( output = pipe(
prompt=prompt, prompt=prompt,
source_prompt=source_prompt, source_prompt=source_prompt,
init_image=init_image, image=init_image,
num_inference_steps=100, num_inference_steps=100,
eta=0.1, eta=0.1,
strength=0.85, strength=0.85,
......
...@@ -72,7 +72,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ...@@ -72,7 +72,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
generator = np.random.RandomState(0) generator = np.random.RandomState(0)
output = pipe( output = pipe(
prompt=prompt, prompt=prompt,
init_image=init_image, image=init_image,
strength=0.75, strength=0.75,
guidance_scale=7.5, guidance_scale=7.5,
num_inference_steps=10, num_inference_steps=10,
...@@ -110,7 +110,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ...@@ -110,7 +110,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
generator = np.random.RandomState(0) generator = np.random.RandomState(0)
output = pipe( output = pipe(
prompt=prompt, prompt=prompt,
init_image=init_image, image=init_image,
strength=0.75, strength=0.75,
guidance_scale=7.5, guidance_scale=7.5,
num_inference_steps=10, num_inference_steps=10,
......
...@@ -80,7 +80,7 @@ class StableDiffusionOnnxInpaintLegacyPipelineIntegrationTests(unittest.TestCase ...@@ -80,7 +80,7 @@ class StableDiffusionOnnxInpaintLegacyPipelineIntegrationTests(unittest.TestCase
generator = np.random.RandomState(0) generator = np.random.RandomState(0)
output = pipe( output = pipe(
prompt=prompt, prompt=prompt,
init_image=init_image, image=init_image,
mask_image=mask_image, mask_image=mask_image,
strength=0.75, strength=0.75,
guidance_scale=7.5, guidance_scale=7.5,
......
...@@ -188,7 +188,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -188,7 +188,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
guidance_scale=6.0, guidance_scale=6.0,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
) )
image = output.images image = output.images
...@@ -200,7 +200,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -200,7 +200,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
guidance_scale=6.0, guidance_scale=6.0,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
return_dict=False, return_dict=False,
)[0] )[0]
...@@ -245,7 +245,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -245,7 +245,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
guidance_scale=6.0, guidance_scale=6.0,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
) )
image = output.images image = output.images
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
...@@ -285,7 +285,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -285,7 +285,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
guidance_scale=6.0, guidance_scale=6.0,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
) )
image = output.images image = output.images
...@@ -328,7 +328,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -328,7 +328,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
guidance_scale=6.0, guidance_scale=6.0,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
) )
image = output.images image = output.images
...@@ -339,7 +339,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -339,7 +339,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
guidance_scale=6.0, guidance_scale=6.0,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
return_dict=False, return_dict=False,
) )
image_from_tuple = output[0] image_from_tuple = output[0]
...@@ -382,7 +382,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -382,7 +382,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
prompt, prompt,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
).images ).images
assert images.shape == (1, 32, 32, 3) assert images.shape == (1, 32, 32, 3)
...@@ -393,7 +393,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -393,7 +393,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
[prompt] * batch_size, [prompt] * batch_size,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
).images ).images
assert images.shape == (batch_size, 32, 32, 3) assert images.shape == (batch_size, 32, 32, 3)
...@@ -404,7 +404,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -404,7 +404,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
prompt, prompt,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
).images ).images
...@@ -416,7 +416,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -416,7 +416,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
[prompt] * batch_size, [prompt] * batch_size,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
).images ).images
...@@ -458,7 +458,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -458,7 +458,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
generator=generator, generator=generator,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
).images ).images
assert image.shape == (1, 32, 32, 3) assert image.shape == (1, 32, 32, 3)
...@@ -497,7 +497,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ...@@ -497,7 +497,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
generator = torch.Generator(device=torch_device).manual_seed(0) generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe( output = pipe(
prompt=prompt, prompt=prompt,
init_image=init_image, image=init_image,
strength=0.75, strength=0.75,
guidance_scale=7.5, guidance_scale=7.5,
generator=generator, generator=generator,
...@@ -535,7 +535,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ...@@ -535,7 +535,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
generator = torch.Generator(device=torch_device).manual_seed(0) generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe( output = pipe(
prompt=prompt, prompt=prompt,
init_image=init_image, image=init_image,
strength=0.75, strength=0.75,
guidance_scale=7.5, guidance_scale=7.5,
generator=generator, generator=generator,
...@@ -572,7 +572,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ...@@ -572,7 +572,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
generator = torch.Generator(device=torch_device).manual_seed(0) generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe( output = pipe(
prompt=prompt, prompt=prompt,
init_image=init_image, image=init_image,
strength=0.75, strength=0.75,
guidance_scale=7.5, guidance_scale=7.5,
generator=generator, generator=generator,
...@@ -626,7 +626,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ...@@ -626,7 +626,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
with torch.autocast(torch_device): with torch.autocast(torch_device):
pipe( pipe(
prompt=prompt, prompt=prompt,
init_image=init_image, image=init_image,
strength=0.75, strength=0.75,
num_inference_steps=50, num_inference_steps=50,
guidance_scale=7.5, guidance_scale=7.5,
...@@ -663,7 +663,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ...@@ -663,7 +663,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
generator = torch.Generator(device=torch_device).manual_seed(0) generator = torch.Generator(device=torch_device).manual_seed(0)
_ = pipe( _ = pipe(
prompt=prompt, prompt=prompt,
init_image=init_image, image=init_image,
strength=0.75, strength=0.75,
guidance_scale=7.5, guidance_scale=7.5,
generator=generator, generator=generator,
......
...@@ -191,7 +191,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes ...@@ -191,7 +191,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
guidance_scale=6.0, guidance_scale=6.0,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
mask_image=mask_image, mask_image=mask_image,
) )
...@@ -204,7 +204,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes ...@@ -204,7 +204,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
guidance_scale=6.0, guidance_scale=6.0,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
mask_image=mask_image, mask_image=mask_image,
return_dict=False, return_dict=False,
)[0] )[0]
...@@ -252,7 +252,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes ...@@ -252,7 +252,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
guidance_scale=6.0, guidance_scale=6.0,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
mask_image=mask_image, mask_image=mask_image,
) )
...@@ -295,7 +295,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes ...@@ -295,7 +295,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
prompt, prompt,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
mask_image=mask_image, mask_image=mask_image,
).images ).images
...@@ -307,7 +307,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes ...@@ -307,7 +307,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
[prompt] * batch_size, [prompt] * batch_size,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
mask_image=mask_image, mask_image=mask_image,
).images ).images
...@@ -319,7 +319,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes ...@@ -319,7 +319,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
prompt, prompt,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
mask_image=mask_image, mask_image=mask_image,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
).images ).images
...@@ -332,7 +332,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes ...@@ -332,7 +332,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
[prompt] * batch_size, [prompt] * batch_size,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
mask_image=mask_image, mask_image=mask_image,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
).images ).images
...@@ -374,7 +374,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase): ...@@ -374,7 +374,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
generator = torch.Generator(device=torch_device).manual_seed(0) generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe( output = pipe(
prompt=prompt, prompt=prompt,
init_image=init_image, image=init_image,
mask_image=mask_image, mask_image=mask_image,
strength=0.75, strength=0.75,
guidance_scale=7.5, guidance_scale=7.5,
...@@ -416,7 +416,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase): ...@@ -416,7 +416,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
generator = torch.Generator(device=torch_device).manual_seed(0) generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe( output = pipe(
prompt=prompt, prompt=prompt,
init_image=init_image, image=init_image,
mask_image=mask_image, mask_image=mask_image,
strength=0.75, strength=0.75,
guidance_scale=7.5, guidance_scale=7.5,
...@@ -474,7 +474,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase): ...@@ -474,7 +474,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
with torch.autocast(torch_device): with torch.autocast(torch_device):
pipe( pipe(
prompt=prompt, prompt=prompt,
init_image=init_image, image=init_image,
mask_image=mask_image, mask_image=mask_image,
strength=0.75, strength=0.75,
num_inference_steps=50, num_inference_steps=50,
......
...@@ -411,7 +411,7 @@ class PipelineFastTests(unittest.TestCase): ...@@ -411,7 +411,7 @@ class PipelineFastTests(unittest.TestCase):
generator=generator, generator=generator,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
mask_image=mask_image, mask_image=mask_image,
).images ).images
image_img2img = img2img( image_img2img = img2img(
...@@ -419,7 +419,7 @@ class PipelineFastTests(unittest.TestCase): ...@@ -419,7 +419,7 @@ class PipelineFastTests(unittest.TestCase):
generator=generator, generator=generator,
num_inference_steps=2, num_inference_steps=2,
output_type="np", output_type="np",
init_image=init_image, image=init_image,
).images ).images
image_text2img = text2img( image_text2img = text2img(
[prompt], [prompt],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment