Unverified commit cd6e1f11, authored by Aryan, committed by GitHub

[docs/nits] Fix return values based on `return_dict` and minor doc updates (#7105)



* fix returns and docs

* handle latent output_type correctly

* revert to old tensor2vid impl

* make fix-copies

* fix return in community animatediff pipes

* fix return docstring

* fix return docs

* add missing quote

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 6f2b310a
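Note on the pattern below: every pipeline touched by this commit used to return its output dataclass early when output_type == "latent", which skipped both the return_dict check and the final model offloading. The diffs route the latents through the same post-processing branch as decoded video instead. A minimal sketch of the method tail the pipelines converge on (not standalone code; decode_latents, tensor2vid, and AnimateDiffPipelineOutput are the existing helpers named in the diff):

    # Sketch of the shared post-processing/return pattern introduced by this commit.
    if output_type == "latent":
        video = latents  # keep the latents, but do not return early
    else:
        video_tensor = self.decode_latents(latents)
        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

    # Offloading now runs for every output_type, including "latent".
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)
    return AnimateDiffPipelineOutput(frames=video)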
@@ -13,7 +13,6 @@
 # limitations under the License.
 import inspect
-from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -27,6 +26,7 @@ from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionL
 from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.models.unets.unet_motion_model import MotionAdapter
+from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.schedulers import (
@@ -37,7 +37,7 @@ from diffusers.schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
@@ -91,10 +91,8 @@ EXAMPLE_DOC_STRING = """
 """
 
 
+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
 def tensor2vid(video: torch.Tensor, processor, output_type="np"):
-    # Based on:
-    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
     batch_size, channels, num_frames, height, width = video.shape
     outputs = []
     for batch_idx in range(batch_size):
@@ -103,12 +101,16 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
         outputs.append(batch_output)
 
-    return outputs
-
-
-@dataclass
-class AnimateDiffControlNetPipelineOutput(BaseOutput):
-    frames: Union[torch.Tensor, np.ndarray]
+    if output_type == "np":
+        outputs = np.stack(outputs)
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
+    return outputs
 
 
 class AnimateDiffControlNetPipeline(
@@ -843,8 +845,8 @@ class AnimateDiffControlNetPipeline(
         Examples:
 
         Returns:
-            [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
+            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
@@ -1020,7 +1022,7 @@ class AnimateDiffControlNetPipeline(
             ]
             controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
 
-        # Denoising loop
+        # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1096,21 +1098,17 @@ class AnimateDiffControlNetPipeline(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
+        # 9. Post processing
         if output_type == "latent":
-            return AnimateDiffControlNetPipelineOutput(frames=latents)
-
-        # Post-processing
-        video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
+            video = latents
         else:
+            video_tensor = self.decode_latents(latents)
             video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
 
-        # Offload all models
+        # 10. Offload all models
         self.maybe_free_model_hooks()
 
         if not return_dict:
             return (video,)
 
-        return AnimateDiffControlNetPipelineOutput(frames=video)
+        return AnimateDiffPipelineOutput(frames=video)
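Note: the community ControlNet pipeline above drops its private AnimateDiffControlNetPipelineOutput dataclass and list-only tensor2vid in favour of the main AnimateDiff versions. A self-contained sketch of the restored helper follows; the two lines inside the loop are collapsed in the diff and reconstructed here from the library, so treat them as an assumption:

    import numpy as np
    import torch


    def tensor2vid(video: torch.Tensor, processor, output_type="np"):
        # Post-process each batch item (frames moved to the leading dim for the image
        # processor), then stack so "np"/"pt" callers get one array/tensor per batch.
        batch_size, channels, num_frames, height, width = video.shape
        outputs = []
        for batch_idx in range(batch_size):
            batch_vid = video[batch_idx].permute(1, 0, 2, 3)  # reconstructed, not shown above
            batch_output = processor.postprocess(batch_vid, output_type)  # reconstructed, not shown above
            outputs.append(batch_output)

        if output_type == "np":
            outputs = np.stack(outputs)
        elif output_type == "pt":
            outputs = torch.stack(outputs)
        elif not output_type == "pil":
            raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

        return outputs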
@@ -158,10 +158,8 @@ def slerp(
     return v2
 
 
+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
 def tensor2vid(video: torch.Tensor, processor, output_type="np"):
-    # Based on:
-    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
     batch_size, channels, num_frames, height, width = video.shape
     outputs = []
     for batch_idx in range(batch_size):
@@ -170,6 +168,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
         outputs.append(batch_output)
 
+    if output_type == "np":
+        outputs = np.stack(outputs)
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
     return outputs
@@ -826,8 +833,8 @@ class AnimateDiffImgToVideoPipeline(
         Examples:
 
         Returns:
-            [`AnimateDiffPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
+            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
 
         # 0. Default height and width to unet
@@ -958,11 +965,10 @@ class AnimateDiffImgToVideoPipeline(
             return AnimateDiffPipelineOutput(frames=latents)
 
         # 10. Post-processing
-        video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
+        if output_type == "latent":
+            video = latents
         else:
+            video_tensor = self.decode_latents(latents)
             video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
 
         # 11. Offload all models
...
@@ -81,7 +81,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)
     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
 
     return outputs
@@ -668,8 +668,8 @@ class AnimateDiffPipeline(
         Examples:
 
         Returns:
-            [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
+            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
@@ -790,6 +790,8 @@ class AnimateDiffPipeline(
         self._num_timesteps = len(timesteps)
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
@@ -829,13 +831,14 @@ class AnimateDiffPipeline(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
+        # 9. Post processing
         if output_type == "latent":
-            return AnimateDiffPipelineOutput(frames=latents)
-
-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
 
-        # 9. Offload all models
+        # 10. Offload all models
         self.maybe_free_model_hooks()
 
         if not return_dict:
...
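Note: with the early return removed from AnimateDiffPipeline.__call__, output_type="latent" and return_dict=False now compose. A hedged usage sketch (checkpoint ids are illustrative and not part of this diff):

    import torch
    from diffusers import AnimateDiffPipeline, MotionAdapter

    # Illustrative checkpoints; any SD 1.5 base plus motion adapter behaves the same way.
    adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
    pipe = AnimateDiffPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", motion_adapter=adapter, torch_dtype=torch.float16
    ).to("cuda")

    # return_dict=False -> plain tuple; the generated frames are its first element.
    frames = pipe("a rocket lifting off", num_frames=16, return_dict=False)[0]

    # output_type="latent" now takes the same exit path, so offloading still runs
    # and the tuple/dataclass choice is still honored.
    latents = pipe("a rocket lifting off", num_frames=16, output_type="latent", return_dict=False)[0]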
@@ -100,7 +100,7 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
         outputs = torch.stack(outputs)
     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
 
     return outputs
@@ -828,8 +828,8 @@ class AnimateDiffVideoToVideoPipeline(
         Examples:
 
         Returns:
-            [`AnimateDiffPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
+            [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
@@ -942,6 +942,7 @@ class AnimateDiffVideoToVideoPipeline(
         self._num_timesteps = len(timesteps)
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
         # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -980,15 +981,11 @@ class AnimateDiffVideoToVideoPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
-        if output_type == "latent":
-            return AnimateDiffPipelineOutput(frames=latents)
-
         # 9. Post-processing
-        video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
+        if output_type == "latent":
+            video = latents
         else:
+            video_tensor = self.decode_latents(latents)
             video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
 
         # 10. Offload all models
...
@@ -83,7 +83,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)
     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
 
     return outputs
@@ -726,13 +726,14 @@ class I2VGenXLPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+        # 8. Post processing
         if output_type == "latent":
-            return I2VGenXLPipelineOutput(frames=latents)
-
-        video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
-        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
+            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
 
-        # Offload all models
+        # 9. Offload all models
         self.maybe_free_model_hooks()
 
         if not return_dict:
...
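Note: the I2VGenXL branch keeps passing decode_chunk_size to decode_latents inside the new else branch. The library's decode_latents is not shown in this diff; the following is only a hedged sketch of what chunked VAE decoding generally looks like, with illustrative variable names:

    # Illustrative sketch of chunked latent decoding (not the library's exact code):
    # pushing frames through the VAE decode_chunk_size at a time bounds peak memory
    # while producing the same concatenated video tensor.
    chunks = []
    for i in range(0, latents.shape[0], decode_chunk_size):
        chunk = latents[i : i + decode_chunk_size]
        chunks.append(vae.decode(chunk / vae.config.scaling_factor).sample)
    video_tensor = torch.cat(chunks)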
@@ -107,7 +107,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)
     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
 
     return outputs
@@ -860,8 +860,8 @@ class PIAPipeline(
         Examples:
 
         Returns:
-            [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
+            [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
 
         # 0. Default height and width to unet
@@ -1018,13 +1018,14 @@ class PIAPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+        # 9. Post processing
         if output_type == "latent":
-            return PIAPipelineOutput(frames=latents)
-
-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
 
-        # 9. Offload all models
+        # 10. Offload all models
         self.maybe_free_model_hooks()
 
         if not return_dict:
...
@@ -74,7 +74,7 @@ def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: s
         outputs = torch.stack(outputs)
     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
 
     return outputs
...
@@ -76,7 +76,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)
     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
 
     return outputs
@@ -647,13 +647,14 @@ class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInve
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+        # 8. Post processing
         if output_type == "latent":
-            return TextToVideoSDPipelineOutput(frames=latents)
-
-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type)
 
-        # Offload all models
+        # 9. Offload all models
         self.maybe_free_model_hooks()
 
         if not return_dict:
...
@@ -111,7 +111,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)
     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
 
     return outputs
@@ -694,13 +694,13 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
-        # 5. Prepare latent variables
+        # 6. Prepare latent variables
         latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator)
 
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7. Denoising loop
+        # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -740,20 +740,18 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
-        if output_type == "latent":
-            return TextToVideoSDPipelineOutput(frames=latents)
-
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
 
+        # 9. Post processing
         if output_type == "latent":
-            return TextToVideoSDPipelineOutput(frames=latents)
-
-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type)
 
-        # Offload all models
+        # 10. Offload all models
         self.maybe_free_model_hooks()
 
         if not return_dict:
...
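Note: the docstring edits above point each Returns section at the pipeline's own output class rather than TextToVideoSDPipelineOutput. For reference, a sketch of such an output dataclass, with the field taken from the community copy removed in this commit (the shared class in diffusers.pipelines.animatediff.pipeline_output may differ slightly, for example by also accepting lists of PIL images):

    from dataclasses import dataclass
    from typing import Union

    import numpy as np
    import torch

    from diffusers.utils import BaseOutput


    @dataclass
    class AnimateDiffPipelineOutput(BaseOutput):
        # Field name and types mirror the removed community dataclass.
        frames: Union[torch.Tensor, np.ndarray]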