Unverified Commit d63a498c authored by vahramtadevosyan, committed by GitHub

[Pipeline] Add TextToVideoZeroSDXLPipeline (#4695)



* integrated sdxl for the text2video-zero pipeline

* make fix-copies

* fixed CI issues

* make fix-copies

* added docs and `copied from` statements

* added fast tests

* made a small change in docs

* quality+style check fix

* updated docs. added controlnet inference with sdxl

* added device compatibility for fast tests

* fixed docstrings

* changing vae upcasting

* remove torch.empty_cache to speed up inference
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* made fast tests to run on dummy models only, fixed copied from statements

* fixed testing utils imports

* Added bullet points for SDXL support

* fixed formatting & quality

* Update tests/pipelines/text_to_video/test_text_to_video_zero_sdxl.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/pipelines/text_to_video/test_text_to_video_zero_sdxl.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fixed minor error for merging

* fixed updates of sdxl

* made fast tests inherit from `PipelineTesterMixin` and run in 3-4secs on CPU

* make style && make quality

* reimplemented fast tests w/o default attn processor

* make style & make quality

* make fix-copies

* make fix-copies

* fixed docs

* make style & make quality & make fix-copies

* bug fix in cross attention

* make style && make quality

* make fix-copies

* fix gpu issues

* make fix-copies

* updated pipeline signature

---------
Co-authored-by: Vahram <vahram.tadevosyan@lambda-loginnode02.cm.cluster>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 6a4aad43
...@@ -92,6 +92,19 @@ imageio.mimsave("video.mp4", result, fps=4)
```
- #### SDXL Support
In order to use the SDXL model when generating a video from a prompt, use the `TextToVideoZeroSDXLPipeline` pipeline:
```python
import torch
from diffusers import TextToVideoZeroSDXLPipeline
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = TextToVideoZeroSDXLPipeline.from_pretrained(
model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
```
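The loaded SDXL pipeline is then called the same way as the example above. A minimal sketch of generating and saving a clip, mirroring that example (the prompt and `fps` value here are only illustrative):
```python
import imageio

prompt = "A panda dancing in Antarctica"
result = pipe(prompt=prompt).images
result = [(r * 255).astype("uint8") for r in result]
imageio.mimsave("video.mp4", result, fps=4)
```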
### Text-To-Video with Pose Control
To generate a video from prompt with additional pose control
...@@ -141,7 +154,33 @@ To generate a video from prompt with additional pose control
result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
imageio.mimsave("video.mp4", result, fps=4)
```
- #### SDXL Support
Since our attention processor also works with SDXL, it can be used to generate a video from a prompt using ControlNet models powered by SDXL:
```python
import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor
controlnet_model_id = 'thibaud/controlnet-openpose-sdxl-1.0'
model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
controlnet = ControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
model_id, controlnet=controlnet, torch_dtype=torch.float16
).to('cuda')
# Set the attention processor
pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
# fix latents for all frames
latents = torch.randn((1, 4, 128, 128), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)
prompt = "Darth Vader dancing in a desert"
result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
imageio.mimsave("video.mp4", result, fps=4)
```
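The example above reuses `pose_images` (and `imageio`) from the pose-control walkthrough earlier on this page. As a hedged sketch, one way to prepare those pose frames, assuming `controlnet_aux`'s OpenPose detector and an illustrative input video path (`dance.mp4`); frames are resized to 1024×1024 to match the 128×128 latents above (SDXL's VAE downsamples by a factor of 8):
```python
import imageio
from PIL import Image
from controlnet_aux import OpenposeDetector

# read a handful of frames from a reference video (path is illustrative)
reader = imageio.get_reader("dance.mp4", "ffmpeg")
frames = [Image.fromarray(reader.get_data(i)).resize((1024, 1024)) for i in range(8)]

# extract an OpenPose skeleton image for every frame
pose_detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
pose_images = [pose_detector(frame) for frame in frames]
```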
### Text-To-Video with Edge Control
...@@ -253,5 +292,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- all
- __call__
## TextToVideoZeroSDXLPipeline
[[autodoc]] TextToVideoZeroSDXLPipeline
- all
- __call__
## TextToVideoPipelineOutput
[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput
...@@ -279,6 +279,7 @@ else:
"StableUnCLIPPipeline", "StableUnCLIPPipeline",
"TextToVideoSDPipeline", "TextToVideoSDPipeline",
"TextToVideoZeroPipeline", "TextToVideoZeroPipeline",
"TextToVideoZeroSDXLPipeline",
"UnCLIPImageVariationPipeline", "UnCLIPImageVariationPipeline",
"UnCLIPPipeline", "UnCLIPPipeline",
"UniDiffuserModel", "UniDiffuserModel",
...@@ -628,6 +629,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -628,6 +629,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableUnCLIPPipeline, StableUnCLIPPipeline,
TextToVideoSDPipeline, TextToVideoSDPipeline,
TextToVideoZeroPipeline, TextToVideoZeroPipeline,
TextToVideoZeroSDXLPipeline,
UnCLIPImageVariationPipeline,
UnCLIPPipeline,
UniDiffuserModel,
...
...@@ -162,6 +162,7 @@ else:
_import_structure["text_to_video_synthesis"] = [
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
"TextToVideoZeroSDXLPipeline",
"VideoToVideoSDPipeline", "VideoToVideoSDPipeline",
] ]
_import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"] _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
...@@ -386,6 +387,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -386,6 +387,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .text_to_video_synthesis import ( from .text_to_video_synthesis import (
TextToVideoSDPipeline, TextToVideoSDPipeline,
TextToVideoZeroPipeline, TextToVideoZeroPipeline,
TextToVideoZeroSDXLPipeline,
VideoToVideoSDPipeline,
)
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
...
...@@ -25,6 +25,7 @@ else:
_import_structure["pipeline_text_to_video_synth"] = ["TextToVideoSDPipeline"]
_import_structure["pipeline_text_to_video_synth_img2img"] = ["VideoToVideoSDPipeline"]
_import_structure["pipeline_text_to_video_zero"] = ["TextToVideoZeroPipeline"]
_import_structure["pipeline_text_to_video_zero_sdxl"] = ["TextToVideoZeroSDXLPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
...@@ -38,6 +39,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_text_to_video_synth import TextToVideoSDPipeline
from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline
from .pipeline_text_to_video_zero import TextToVideoZeroPipeline
from .pipeline_text_to_video_zero_sdxl import TextToVideoZeroSDXLPipeline
else:
import sys
...
...@@ -13,6 +13,7 @@ from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
def rearrange_0(tensor, f):
...@@ -135,7 +136,7 @@ class CrossFrameAttnProcessor2_0:
# Cross Frame Attention
if not is_cross_attention:
- video_length = key.size()[0] // self.batch_size
+ video_length = max(1, key.size()[0] // self.batch_size)
first_frame_index = [0] * video_length
# rearrange keys to have batch and frames in the 1st and 2nd dims respectively
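For readers unfamiliar with the processor, the hunk above is the heart of cross-frame attention: each frame's keys (and values) are replaced by those of the first frame, so all frames attend to frame 0 and stay temporally consistent. A small, self-contained illustration of that rearrangement (shapes and names are made up, not the pipeline's exact code):
```python
import torch

batch_size, video_length, seq_len, dim = 2, 3, 16, 8

# keys arrive flattened as (batch_size * video_length, seq_len, dim)
key = torch.randn(batch_size * video_length, seq_len, dim)

# group into (batch, frames, seq, dim), then pick frame 0 for every frame
key = key.reshape(batch_size, video_length, seq_len, dim)
first_frame_index = [0] * video_length
key = key[:, first_frame_index]  # every frame now uses frame 0's keys
key = key.reshape(batch_size * video_length, seq_len, dim)
```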
...@@ -339,7 +340,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
x_t1:
Forward process applied to x_t0 from time t0 to t1.
"""
- eps = torch.randn(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device)
+ eps = randn_tensor(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device)
alpha_vec = torch.prod(self.scheduler.alphas[t0:t1])
x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps
return x_t1
...
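The `forward_loop` hunk above only swaps `torch.randn` for diffusers' `randn_tensor` helper (so the generator's device is respected); the DDPM forward step itself is unchanged: `x_t1 = sqrt(prod(alphas[t0:t1])) * x_t0 + sqrt(1 - prod(alphas[t0:t1])) * eps`. A toy, standalone sketch of that step (the alpha schedule here is made up, not the scheduler's real values):
```python
import torch

# stand-in for scheduler.alphas (a real scheduler provides these)
alphas = torch.linspace(0.9999, 0.98, 1000)

def forward_from_t0_to_t1(x_t0, t0, t1, generator=None):
    # total signal retention accumulated over timesteps [t0, t1)
    alpha_vec = torch.prod(alphas[t0:t1])
    eps = torch.randn(x_t0.shape, generator=generator, dtype=x_t0.dtype)
    # DDPM forward marginal: x_t1 ~ N(sqrt(alpha) * x_t0, (1 - alpha) * I)
    return torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps

x_t1 = forward_from_t0_to_t1(torch.randn(1, 4, 64, 64), t0=44, t1=47)
```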
...@@ -1202,6 +1202,21 @@ class TextToVideoZeroPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class TextToVideoZeroSDXLPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class UnCLIPImageVariationPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
import io
import re
import tempfile
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, TextToVideoZeroSDXLPipeline, UNet2DConditionModel
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = TextToVideoZeroSDXLPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
generator_device = "cpu"
def get_dummy_components(self, seed=0):
torch.manual_seed(seed)
unet = UNet2DConditionModel(
block_out_channels=(2, 4),
layers_per_block=2,
sample_size=2,
norm_num_groups=2,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
# SD2-specific config below
attention_head_dim=(2, 4),
use_linear_projection=True,
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
)
scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
clip_sample=True,
set_alpha_to_one=True,
steps_offset=0,
prediction_type="epsilon",
thresholding=False,
dynamic_thresholding_ratio=0.995,
clip_sample_range=1.0,
sample_max_value=1.0,
timestep_spacing="leading",
rescale_betas_zero_snr=False,
)
torch.manual_seed(seed)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=128,
)
torch.manual_seed(seed)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
# SD2-specific config below
hidden_act="gelu",
projection_dim=32,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": None,
"feature_extractor": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A panda dancing in Antarctica",
"generator": generator,
"num_inference_steps": 5,
"t0": 1,
"t1": 3,
"height": 64,
"width": 64,
"video_length": 3,
"output_type": "np",
}
return inputs
def get_generator(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
return generator
def test_text_to_video_zero_sdxl(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(self.generator_device)
result = pipe(**inputs).images
first_frame_slice = result[0, -3:, -3:, -1]
last_frame_slice = result[-1, -3:, -3:, 0]
expected_slice1 = np.array([0.48, 0.58, 0.53, 0.59, 0.50, 0.44, 0.60, 0.65, 0.52])
expected_slice2 = np.array([0.66, 0.49, 0.40, 0.70, 0.47, 0.51, 0.73, 0.65, 0.52])
assert np.abs(first_frame_slice.flatten() - expected_slice1).max() < 1e-2
assert np.abs(last_frame_slice.flatten() - expected_slice2).max() < 1e-2
@unittest.skip(
reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
)
def test_attention_slicing_forward_pass(self):
pass
def test_cfg(self):
sig = inspect.signature(self.pipeline_class.__call__)
if "guidance_scale" not in sig.parameters:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(self.generator_device)
inputs["guidance_scale"] = 1.0
out_no_cfg = pipe(**inputs)[0]
inputs["guidance_scale"] = 7.5
out_cfg = pipe(**inputs)[0]
assert out_cfg.shape == out_no_cfg.shape
def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(self.generator_device))[0]
output_tuple = pipe(**self.get_dummy_inputs(self.generator_device), return_dict=False)[0]
max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
self.assertLess(max_diff, expected_max_difference)
@unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
def test_float16_inference(self, expected_max_diff=5e-2):
components = self.get_dummy_components()
for name, module in components.items():
if hasattr(module, "half"):
components[name] = module.to(torch_device).half()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
components = self.get_dummy_components()
pipe_fp16 = self.pipeline_class(**components)
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(self.generator_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(self.generator_device)
output = pipe(**inputs)[0]
fp16_inputs = self.get_dummy_inputs(self.generator_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(self.generator_device)
output_fp16 = pipe_fp16(**fp16_inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
@unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
def test_inference_batch_consistent(self):
pass
@unittest.skip(
reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
)
def test_inference_batch_single_identical(self):
pass
@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
)
def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(self.generator_device)
output_without_offload = pipe(**inputs)[0]
pipe.enable_model_cpu_offload()
inputs = self.get_dummy_inputs(self.generator_device)
output_with_offload = pipe(**inputs)[0]
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
@unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
def test_pipeline_call_signature(self):
pass
def test_progress_bar(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
inputs = self.get_dummy_inputs(self.generator_device)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
# so we just match "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
@unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
if hasattr(module, "half"):
components[name] = module.to(torch_device).half()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(self.generator_device)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for name, component in pipe_loaded.components.items():
if hasattr(component, "dtype"):
self.assertTrue(
component.dtype == torch.float16,
f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
)
inputs = self.get_dummy_inputs(self.generator_device)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
)
@unittest.skip(
reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
)
def test_save_load_local(self):
pass
@unittest.skip(
reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
)
def test_save_load_optional_components(self):
pass
@unittest.skip(
reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
)
def test_sequential_cpu_offload_forward_pass(self):
pass
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to("cuda")
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
self.assertTrue(all(device == "cuda" for device in model_devices))
output_cuda = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
@unittest.skip(
reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
)
def test_xformers_attention_forwardGenerator_pass(self):
pass
@nightly
@require_torch_gpu
class TextToVideoZeroSDXLPipelineSlowTests(unittest.TestCase):
def test_full_model(self):
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = TextToVideoZeroSDXLPipeline.from_pretrained(
model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "A panda dancing in Antarctica"
result = pipe(prompt=prompt, generator=generator).images
first_frame_slice = result[0, -3:, -3:, -1]
last_frame_slice = result[-1, -3:, -3:, 0]
expected_slice1 = np.array([0.57, 0.57, 0.57, 0.57, 0.57, 0.56, 0.55, 0.56, 0.56])
expected_slice2 = np.array([0.54, 0.53, 0.53, 0.53, 0.53, 0.52, 0.53, 0.53, 0.53])
assert np.abs(first_frame_slice.flatten() - expected_slice1).max() < 1e-2
assert np.abs(last_frame_slice.flatten() - expected_slice2).max() < 1e-2