"server/vscode:/vscode.git/clone" did not exist on "5c7c9f13903f09636aaf99210710bf07002cdb87"
Unverified Commit 11b3002b authored by Isotr0py, committed by GitHub

Support views batch for panorama (#3632)



* support views batch for panorama

* add entry for the new argument

* format entry for the new argument

* add view_batch_size test

* fix batch test and a boundary condition

* add more docstrings

* fix a typo

* fix typos

* add: entry to the doc about view_batch_size.

* Revert "add: entry to the doc about view_batch_size."

This reverts commit a36aeaa9edf9b662d09bbfd6e18cbc556ed38187.

* add a tip on `view_batch_size`.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 10f4ecd1
@@ -52,6 +52,14 @@ image = pipe(prompt).images[0]
 
 image.save("dolomites.png")
 ```
+
+<Tip>
+
+When calling this pipeline, you can set `view_batch_size` to a value greater than 1. On some
+high-performance GPUs, a higher `view_batch_size` can speed up generation at the cost of increased VRAM usage.
+
+</Tip>
+
 ## StableDiffusionPanoramaPipeline
 [[autodoc]] StableDiffusionPanoramaPipeline
 	- __call__
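For context, here is a minimal usage sketch of the new argument, following the example above the Tip (the checkpoint and the value `view_batch_size=4` are illustrative, not part of this commit):

```python
import torch
from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler

model_ckpt = "stabilityai/stable-diffusion-2-base"
scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
pipe = StableDiffusionPanoramaPipeline.from_pretrained(
    model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of the dolomites"
# Denoise 4 sliding-window views per UNet forward pass instead of 1;
# faster on large GPUs, at the cost of extra VRAM.
image = pipe(prompt, view_batch_size=4).images[0]
image.save("dolomites.png")
```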
@@ -451,10 +451,11 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
     def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
         # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
+        # if panorama's height/width < window_size, num_blocks of height/width should be 1
         panorama_height /= 8
         panorama_width /= 8
-        num_blocks_height = (panorama_height - window_size) // stride + 1
-        num_blocks_width = (panorama_width - window_size) // stride + 1
+        num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
+        num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
         total_num_blocks = int(num_blocks_height * num_blocks_width)
         views = []
         for i in range(total_num_blocks):
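To make the block arithmetic concrete, a small standalone sketch, assuming the pipeline defaults (`height=512`, `width=2048`, latent downscale factor 8, `window_size=64`, `stride=8`):

```python
# Standalone reproduction of the view-count arithmetic above.
height, width = 512, 2048
window_size, stride = 64, 8

lat_h, lat_w = height / 8, width / 8  # latent-space size: 64 x 256

num_blocks_h = (lat_h - window_size) // stride + 1 if lat_h > window_size else 1
num_blocks_w = (lat_w - window_size) // stride + 1 if lat_w > window_size else 1

print(num_blocks_h, num_blocks_w)        # 1 25.0 -> one row of sliding windows
print(int(num_blocks_h * num_blocks_w))  # 25 views in total
```

With 25 views, a `view_batch_size` of 8 would cut the number of UNet invocations per denoising step from 25 to 4.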
@@ -474,6 +475,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         width: Optional[int] = 2048,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
+        view_batch_size: int = 1,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
@@ -508,6 +510,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            view_batch_size (`int`, *optional*, defaults to 1):
+                The batch size to use when denoising split views. On some high-performance GPUs, a higher view
+                batch size can speed up generation at the cost of increased VRAM usage.
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -609,8 +614,11 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         )
 
         # 6. Define panorama grid and initialize views for synthesis.
+        # prepare batch grid
         views = self.get_views(height, width)
-        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views)
+        views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
+        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch)
         count = torch.zeros_like(latents)
         value = torch.zeros_like(latents)
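The added list comprehension is a plain chunking of the views list; a minimal sketch of its behavior, reusing the 25-view count assumed in the earlier example:

```python
# Chunk a list of views into batches of at most view_batch_size, as above.
views = list(range(25))  # stand-in for the 25 (h_start, h_end, w_start, w_end) tuples
view_batch_size = 4

views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]

print([len(b) for b in views_batch])  # [4, 4, 4, 4, 4, 4, 1] -- the last batch may be smaller
```

The possibly-smaller last batch is why the denoising loop below recomputes `vb_size = len(batch_view)` on every iteration rather than reusing `view_batch_size`.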
@@ -631,42 +639,55 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
                 # denoised (latent) crops are then averaged to produce the final latent
                 # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
                 # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
-                for j, (h_start, h_end, w_start, w_end) in enumerate(views):
+                # Batch views denoise
+                for j, batch_view in enumerate(views_batch):
+                    vb_size = len(batch_view)
                     # get the latents corresponding to the current view coordinates
-                    latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
+                    latents_for_view = torch.cat(
+                        [latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view]
+                    )
 
                     # rematch block's scheduler status
                     self.scheduler.__dict__.update(views_scheduler_status[j])
 
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = (
-                        torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
+                        latents_for_view.repeat_interleave(2, dim=0)
+                        if do_classifier_free_guidance
+                        else latents_for_view
                     )
 
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
+                    # repeat prompt_embeds for batch
+                    prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
+
                     # predict the noise residual
                     noise_pred = self.unet(
                         latent_model_input,
                         t,
-                        encoder_hidden_states=prompt_embeds,
+                        encoder_hidden_states=prompt_embeds_input,
                         cross_attention_kwargs=cross_attention_kwargs,
                     ).sample
 
                     # perform guidance
                     if do_classifier_free_guidance:
-                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
                         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                     # compute the previous noisy sample x_t -> x_t-1
-                    latents_view_denoised = self.scheduler.step(
+                    latents_denoised_batch = self.scheduler.step(
                         noise_pred, t, latents_for_view, **extra_step_kwargs
                     ).prev_sample
 
                     # save views scheduler status after sample
                     views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)
 
-                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
-                    count[:, :, h_start:h_end, w_start:w_end] += 1
+                    # extract value from batch
+                    for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
+                        latents_denoised_batch.chunk(vb_size), batch_view
+                    ):
+                        value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
+                        count[:, :, h_start:h_end, w_start:w_end] += 1
 
                 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
                 latents = torch.where(count > 0, value / count, value)
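The switch from `torch.cat([x] * 2)` to `repeat_interleave(2, dim=0)` changes the layout of the classifier-free-guidance batch, which is why the unconditional/text predictions are now recovered with strided slices instead of `chunk(2)`. A minimal sketch of the layout (the tensors are illustrative stand-ins, not real latents or embeddings):

```python
import torch

vb_size = 3  # three views in the batch (illustrative)
views = torch.arange(vb_size).float().reshape(vb_size, 1)  # stand-ins for view latents

# repeat_interleave duplicates each view in place: [v0, v0, v1, v1, v2, v2]
latent_model_input = views.repeat_interleave(2, dim=0)

# For CFG, prompt_embeds is [uncond, text]; tiling it vb_size times gives
# [u, t, u, t, u, t], which lines up pairwise with the interleaved latents.
prompt_embeds = torch.tensor([[0.0], [1.0]])  # 0 = uncond, 1 = text (stand-ins)
prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)

# Pair each latent with its embedding to mimic a "prediction" batch:
noise_pred = torch.cat([latent_model_input, prompt_embeds_input], dim=1)

# Even rows are unconditional, odd rows text-conditioned, so slicing
# replaces the old noise_pred.chunk(2):
noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
print(noise_pred_uncond)  # v0, v1, v2 paired with the uncond embedding
print(noise_pred_text)    # v0, v1, v2 paired with the text embedding
```

After the view loop, `value` and `count` hold the sums of overlapping denoised crops, and `torch.where(count > 0, value / count, value)` realizes the per-pixel averaging of Eq. 5 in the MultiDiffusion paper.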
@@ -131,7 +131,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
     # override to speed the overall test timing up.
     def test_inference_batch_single_identical(self):
-        super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3e-3)
+        super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3.25e-3)
 
     def test_stable_diffusion_panorama_negative_prompt(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -152,6 +152,24 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_stable_diffusion_panorama_views_batch(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPanoramaPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs, view_batch_size=2)
+        image = output.images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+
+        expected_slice = np.array([0.6187, 0.5375, 0.4915, 0.4136, 0.4114, 0.4563, 0.5128, 0.4976, 0.4757])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_panorama_euler(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()