Unverified commit b9b89162, authored by Andranik Movsisyan, committed by GitHub

Text2video zero refinements (#3070)

* fix progress bar total in pipeline_text_to_video_zero.py; copy the scheduler after the first backward loop

* fix tensor loading in test_text_to_video_zero.py

* make style && make quality
parent a4393437
pipeline_text_to_video_zero.py
+import copy
 from dataclasses import dataclass
 from typing import Callable, List, Optional, Union
@@ -56,8 +57,8 @@ class CrossFrameAttnProcessor:
         is_cross_attention = encoder_hidden_states is not None
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
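For context, the rename tracks diffusers' `Attention` block: `norm_cross` now holds the optional norm module itself (so its truthiness gates the branch) and `norm_encoder_hidden_states` applies it. A toy sketch of that pattern under those assumptions (not the library's actual class):

```python
import torch
import torch.nn as nn


# Assumption-level sketch: the cross-attention norm is stored as a module (or
# None) and applied by a helper method, instead of branching on a config flag.
class MiniCrossAttn(nn.Module):
    def __init__(self, query_dim: int, cross_dim: int, use_norm: bool = True):
        super().__init__()
        self.norm_cross = nn.LayerNorm(cross_dim) if use_norm else None
        self.to_k = nn.Linear(cross_dim, query_dim)
        self.to_v = nn.Linear(cross_dim, query_dim)

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        return self.norm_cross(encoder_hidden_states)

    def key_value(self, hidden_states, encoder_hidden_states=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # self-attention path
        elif self.norm_cross:
            encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
        return self.to_k(encoder_hidden_states), self.to_v(encoder_hidden_states)
```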
@@ -285,7 +286,8 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
             latents: latents of backward process output at time timesteps[-1]
         """
         do_classifier_free_guidance = guidance_scale > 1.0
-        with self.progress_bar(total=len(timesteps)) as progress_bar:
+        num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order
+        with self.progress_bar(total=num_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
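The new total matches how often the standard Stable Diffusion denoising loop actually calls `progress_bar.update()`: once per scheduler-order steps, after the warmup steps. A minimal sketch of the arithmetic, with illustrative values not taken from the PR:

```python
# Sketch: the loop updates the bar once per scheduler-order steps after warmup,
# so total=len(timesteps) overshoots whenever order > 1 or warmup > 0.
def progress_total(num_timesteps: int, num_warmup_steps: int, order: int) -> int:
    return (num_timesteps - num_warmup_steps) // order


assert progress_total(50, 0, 1) == 50  # plain DDIM: behaviour unchanged
assert progress_total(51, 1, 2) == 25  # higher-order scheduler: bar now reaches 100%
```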
@@ -465,6 +467,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
             extra_step_kwargs=extra_step_kwargs,
             num_warmup_steps=num_warmup_steps,
         )
+        scheduler_copy = copy.deepcopy(self.scheduler)

         # Perform the second backward process up to time T_0
         x_1_t0 = self.backward_loop(
@@ -475,7 +478,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
             callback=callback,
             callback_steps=callback_steps,
             extra_step_kwargs=extra_step_kwargs,
-            num_warmup_steps=num_warmup_steps,
+            num_warmup_steps=0,
         )

         # Propagate first frame latents at time T_0 to remaining frames
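The `0` here is deliberate: warmup steps sit at the front of the schedule, and this call only ever sees a later slice of `timesteps`. A hedged sketch with illustrative numbers (the split point is made up, not the PR's):

```python
# Sketch: one 50-step schedule consumed in slices; only the first slice can
# contain warmup steps, so later backward_loop calls pass num_warmup_steps=0.
timesteps = list(range(981, -1, -20))        # 50 DDIM timesteps (example values)
num_warmup_steps = len(timesteps) - 50 * 1   # standard SD formula; order 1 -> 0
head, tail = timesteps[:40], timesteps[40:]  # arbitrary split for illustration
# backward_loop(head, ..., num_warmup_steps=num_warmup_steps)
# backward_loop(tail, ..., num_warmup_steps=0)  # warmup already consumed by `head`
```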
@@ -502,7 +505,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
         b, l, d = prompt_embeds.size()
         prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)

-        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        self.scheduler = scheduler_copy
         x_1k_0 = self.backward_loop(
             timesteps=timesteps[-t1 - 1 :],
             prompt_embeds=prompt_embeds,
@@ -511,7 +514,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
             callback=callback,
             callback_steps=callback_steps,
             extra_step_kwargs=extra_step_kwargs,
-            num_warmup_steps=num_warmup_steps,
+            num_warmup_steps=0,
         )
         latents = x_1k_0
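The deepcopy/restore pair replaces the old `set_timesteps()` call: re-initializing the scheduler would reset internal state that the first backward loop already advanced, whereas restoring the snapshot resumes the per-frame loops from exactly the post-first-pass state. A toy sketch of the pattern (the real DDIMScheduler state is richer than one counter; this is an assumption-level illustration):

```python
import copy


class ToyScheduler:
    """Stand-in for a scheduler whose step() mutates internal state."""

    def __init__(self):
        self.timesteps = []
        self.counter = 0  # proxy for state advanced by step()

    def set_timesteps(self, n):
        self.timesteps = list(range(n - 1, -1, -1))
        self.counter = 0

    def step(self):
        self.counter += 1


sched = ToyScheduler()
sched.set_timesteps(50)
for _ in sched.timesteps[:40]:  # first backward loop (down to time T_1)
    sched.step()

snapshot = copy.deepcopy(sched)  # freeze the state right after the first loop

for _ in sched.timesteps[40:]:   # second loop continues to T_0 for frame 1
    sched.step()

sched = snapshot                 # restore instead of set_timesteps(): the loop
for _ in sched.timesteps[40:]:   # for frames 2..m resumes from the same state
    sched.step()
```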
diffusers/utils/__init__.py
@@ -86,6 +86,7 @@ if is_torch_available():
         load_hf_numpy,
         load_image,
         load_numpy,
+        load_pt,
         nightly,
         parse_flag_from_env,
         print_tensor_test,
test_text_to_video_zero.py
@@ -18,7 +18,7 @@ import unittest
 import torch

 from diffusers import DDIMScheduler, TextToVideoZeroPipeline
-from diffusers.utils import require_torch_gpu, slow
+from diffusers.utils import load_pt, require_torch_gpu, slow

 from ...test_pipelines_common import assert_mean_pixel_difference
@@ -35,8 +35,8 @@ class TextToVideoZeroPipelineSlowTests(unittest.TestCase):
         prompt = "A bear is playing a guitar on Times Square"
         result = pipe(prompt=prompt, generator=generator).images
-        expected_result = torch.load(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/tree/main/text-to-video/A bear is playing a guitar on Times Square.pt"
+        expected_result = load_pt(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/A bear is playing a guitar on Times Square.pt"
         )
         assert_mean_pixel_difference(result, expected_result)
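This hunk fixes two problems at once: `torch.load` cannot read from an HTTP(S) URL, and the `/tree/main/...` path points at an HTML listing page rather than the raw file, which `/resolve/main/...` serves. A hedged sketch of what a `load_pt`-style helper can look like (an assumption, not diffusers' actual implementation):

```python
import io

import requests
import torch


# Assumption-level sketch of a load_pt-style helper: fetch the raw bytes over
# HTTP, then torch.load from an in-memory buffer.
def load_pt_sketch(url: str):
    response = requests.get(url)
    response.raise_for_status()
    return torch.load(io.BytesIO(response.content))
```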