Unverified Commit 8d891e6e authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Torch Compile] Fix torch compile for svd vae (#6217)

parent cce1fe2d
...@@ -25,7 +25,7 @@ from ...image_processor import VaeImageProcessor ...@@ -25,7 +25,7 @@ from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from ...schedulers import EulerDiscreteScheduler from ...schedulers import EulerDiscreteScheduler
from ...utils import BaseOutput, logging from ...utils import BaseOutput, logging
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
...@@ -211,7 +211,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline): ...@@ -211,7 +211,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys()) forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
# decode decode_chunk_size frames at a time to avoid OOM # decode decode_chunk_size frames at a time to avoid OOM
frames = [] frames = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment