Unverified Commit 83da817f authored by SahilCarterr's avatar SahilCarterr Committed by GitHub
Browse files

[Add] torch_xla support to pipeline_sana.py (#10364)

[Add] torch_xla support in pipeline_sana.py
parent f430a0cf
......@@ -31,6 +31,7 @@ from ...utils import (
USE_PEFT_BACKEND,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
......@@ -46,6 +47,13 @@ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
from .pipeline_output import SanaPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_bs4_available():
......@@ -864,6 +872,9 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment