Unverified Commit 4520e122 authored by Junsong Chen, committed by GitHub

adapt PixArtAlphaPipeline for pixart-lcm model (#5974)



* adapt PixArtAlphaPipeline for pixart-lcm model

* remove original_inference_steps from __call__

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 61826040
@@ -134,6 +134,51 @@ ASPECT_RATIO_512_BIN = {
}
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
        the second element is the number of inference steps.
    """
    if timesteps is not None:
        # Only schedulers whose `set_timesteps` accepts a `timesteps` argument can use a custom schedule.
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
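For context, the helper can be exercised on its own outside the pipeline. A minimal sketch, assuming `retrieve_timesteps` as defined above is importable; the scheduler choice and step count are illustrative, not part of this commit:

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()

# Default path: the scheduler builds its own spacing for the requested step count.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=4)
print(timesteps, num_inference_steps)  # a tensor of 4 training-timestep indices, and 4

# Custom path: retrieve_timesteps(scheduler, timesteps=[...]) only works when the
# scheduler's `set_timesteps` accepts a `timesteps` argument; for schedulers that
# do not (DDIMScheduler here), the helper raises a ValueError instead.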
class PixArtAlphaPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using PixArt-Alpha.
@@ -783,8 +828,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
        # 4. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        ...
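At the call site, routing through `retrieve_timesteps` is what lets a PixArt-LCM checkpoint run in a handful of steps with guidance disabled. A rough usage sketch; the Hub repo id, prompt, and settings below are illustrative assumptions, not values fixed by this commit:

import torch
from diffusers import PixArtAlphaPipeline

# Assumed PixArt-LCM checkpoint id; substitute whichever LCM-distilled PixArt weights you use.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-LCM-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# LCM-distilled models bake guidance into the weights, so classifier-free guidance
# is disabled (guidance_scale <= 1.0) and only a few inference steps are needed.
image = pipe(
    "A small cactus with a happy face in the Sahara desert",
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
image.save("pixart_lcm_sample.png")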