Unverified Commit 11b3002b authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

Support views batch for panorama (#3632)



* support views batch for panorama

* add entry for the new argument

* format entry for the new argument

* add view_batch_size test

* fix batch test and a boundary condition

* add more docstrings

* fix a typos

* fix typos

* add: entry to the doc about view_batch_size.

* Revert "add: entry to the doc about view_batch_size."

This reverts commit a36aeaa9edf9b662d09bbfd6e18cbc556ed38187.

* add a tip on `view_batch_size`.

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 10f4ecd1
...@@ -52,6 +52,14 @@ image = pipe(prompt).images[0] ...@@ -52,6 +52,14 @@ image = pipe(prompt).images[0]
image.save("dolomites.png") image.save("dolomites.png")
``` ```
<Tip>
While calling this pipeline, it is possible to specify `view_batch_size` with a value greater than 1.
On some high-performance GPUs, a higher `view_batch_size` can speed up the generation,
at the cost of increased VRAM usage.
</Tip>
## StableDiffusionPanoramaPipeline ## StableDiffusionPanoramaPipeline
[[autodoc]] StableDiffusionPanoramaPipeline [[autodoc]] StableDiffusionPanoramaPipeline
- __call__ - __call__
......
...@@ -451,10 +451,11 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -451,10 +451,11 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
def get_views(self, panorama_height, panorama_width, window_size=64, stride=8): def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
# if panorama's height/width < window_size, num_blocks of height/width should return 1
panorama_height /= 8 panorama_height /= 8
panorama_width /= 8 panorama_width /= 8
num_blocks_height = (panorama_height - window_size) // stride + 1 num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
num_blocks_width = (panorama_width - window_size) // stride + 1 num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
total_num_blocks = int(num_blocks_height * num_blocks_width) total_num_blocks = int(num_blocks_height * num_blocks_width)
views = [] views = []
for i in range(total_num_blocks): for i in range(total_num_blocks):
...@@ -474,6 +475,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -474,6 +475,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
width: Optional[int] = 2048, width: Optional[int] = 2048,
num_inference_steps: int = 50, num_inference_steps: int = 50,
guidance_scale: float = 7.5, guidance_scale: float = 7.5,
view_batch_size: int = 1,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1, num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0, eta: float = 0.0,
...@@ -508,6 +510,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -508,6 +510,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. usually at the expense of lower image quality.
view_batch_size (`int`, *optional*, defaults to 1):
The batch size to denoise split views. For some GPUs with high performance, a higher view batch size
can speedup the generation and increase the VRAM usage.
negative_prompt (`str` or `List[str]`, *optional*): negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
...@@ -609,8 +614,11 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -609,8 +614,11 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
) )
# 6. Define panorama grid and initialize views for synthesis. # 6. Define panorama grid and initialize views for synthesis.
# prepare batch grid
views = self.get_views(height, width) views = self.get_views(height, width)
views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views) views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch)
count = torch.zeros_like(latents) count = torch.zeros_like(latents)
value = torch.zeros_like(latents) value = torch.zeros_like(latents)
...@@ -631,42 +639,55 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -631,42 +639,55 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
# denoised (latent) crops are then averaged to produce the final latent # denoised (latent) crops are then averaged to produce the final latent
# for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
# MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113 # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
for j, (h_start, h_end, w_start, w_end) in enumerate(views): # Batch views denoise
for j, batch_view in enumerate(views_batch):
vb_size = len(batch_view)
# get the latents corresponding to the current view coordinates # get the latents corresponding to the current view coordinates
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end] latents_for_view = torch.cat(
[latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view]
)
# rematch block's scheduler status # rematch block's scheduler status
self.scheduler.__dict__.update(views_scheduler_status[j]) self.scheduler.__dict__.update(views_scheduler_status[j])
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = ( latent_model_input = (
torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view latents_for_view.repeat_interleave(2, dim=0)
if do_classifier_free_guidance
else latents_for_view
) )
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# repeat prompt_embeds for batch
prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
# predict the noise residual # predict the noise residual
noise_pred = self.unet( noise_pred = self.unet(
latent_model_input, latent_model_input,
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds_input,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
).sample ).sample
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents_view_denoised = self.scheduler.step( latents_denoised_batch = self.scheduler.step(
noise_pred, t, latents_for_view, **extra_step_kwargs noise_pred, t, latents_for_view, **extra_step_kwargs
).prev_sample ).prev_sample
# save views scheduler status after sample # save views scheduler status after sample
views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__) views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised # extract value from batch
count[:, :, h_start:h_end, w_start:w_end] += 1 for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
latents_denoised_batch.chunk(vb_size), batch_view
):
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = torch.where(count > 0, value / count, value) latents = torch.where(count > 0, value / count, value)
......
...@@ -131,7 +131,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli ...@@ -131,7 +131,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
# override to speed the overall test timing up. # override to speed the overall test timing up.
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3e-3) super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3.25e-3)
def test_stable_diffusion_panorama_negative_prompt(self): def test_stable_diffusion_panorama_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
...@@ -152,6 +152,24 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli ...@@ -152,6 +152,24 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
# NOTE(review): this hunk comes from a scraped diff and has lost its method-level
# indentation; the code below is kept verbatim, with review comments added only.
# Regression test for the new `view_batch_size` argument: running the panorama
# pipeline with two views denoised per UNet forward pass must still produce a
# (1, 64, 64, 3) image whose corner slice matches the pinned reference values.
def test_stable_diffusion_panorama_views_batch(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPanoramaPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
# exercise the batched-views code path (default is view_batch_size=1)
output = sd_pipe(**inputs, view_batch_size=2)
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
# pinned expected pixels for this dummy configuration — presumably generated
# against the unbatched path so batching must not change the output; verify upstream
expected_slice = np.array([0.6187, 0.5375, 0.4915, 0.4136, 0.4114, 0.4563, 0.5128, 0.4976, 0.4757])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_panorama_euler(self): def test_stable_diffusion_panorama_euler(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components() components = self.get_dummy_components()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment