"docs/vscode:/vscode.git/clone" did not exist on "aef00c4d4d31ed63db1d80f234daba9f876dfe5a"
Unverified Commit a17832b2 authored by chaowenguo's avatar chaowenguo Committed by GitHub
Browse files

add pythor_xla support for render a video (#10443)



* Update rerender_a_video.py

* Update rerender_a_video.py

* make style

---------
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent c28db0aa
...@@ -30,10 +30,17 @@ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel ...@@ -30,10 +30,17 @@ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import BaseOutput, deprecate, logging from diffusers.utils import BaseOutput, deprecate, is_torch_xla_available, logging
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -1100,6 +1107,9 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline): ...@@ -1100,6 +1107,9 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
return latents return latents
if mask_start_t <= mask_end_t: if mask_start_t <= mask_end_t:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment