Commit 0513d03d authored by jerrrrry

Initial commit
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, List, Tuple, Callable, Optional, Union
import torch
import torch.distributed
from diffusers import StableDiffusion3Pipeline
from diffusers.utils import is_torch_xla_available
from diffusers.pipelines.stable_diffusion_3.pipeline_output import (
StableDiffusion3PipelineOutput,
)
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
retrieve_timesteps,
)
from xfuser.config import EngineConfig, InputConfig
from xfuser.core.distributed import (
get_pipeline_parallel_world_size,
is_pipeline_first_stage,
is_pipeline_last_stage,
get_runtime_state,
get_cfg_group,
get_classifier_free_guidance_world_size,
get_pp_group,
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
get_sp_group,
is_dp_last_group,
)
from .base_pipeline import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
@xFuserPipelineWrapperRegister.register(StableDiffusion3Pipeline)
class xFuserStableDiffusion3Pipeline(xFuserPipelineBaseWrapper):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
engine_config: EngineConfig,
**kwargs,
):
pipeline = StableDiffusion3Pipeline.from_pretrained(
pretrained_model_name_or_path, **kwargs
)
return cls(pipeline, engine_config)
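# A minimal usage sketch (hedged): the model id below is illustrative and the
# exact EngineConfig construction depends on the xfuser version in use.
#   engine_config = ...  # built from xfuser.config.EngineConfig / CLI args
#   pipe = xFuserStableDiffusion3Pipeline.from_pretrained(
#       "stabilityai/stable-diffusion-3-medium-diffusers",
#       engine_config=engine_config,
#       torch_dtype=torch.float16,
#   )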
def prepare_run(
self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1
):
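# Run a short dummy generation to warm up CUDA kernels and the parallel
# communication groups; warmup_steps is temporarily overridden so that the
# first `sync_steps` steps run synchronously before the real workload.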
prompt = [""] * input_config.batch_size if input_config.batch_size > 1 else ""
warmup_steps = get_runtime_state().runtime_config.warmup_steps
get_runtime_state().runtime_config.warmup_steps = sync_steps
self.__call__(
height=input_config.height,
width=input_config.width,
prompt=prompt,
num_inference_steps=steps,
generator=torch.Generator(device="cuda").manual_seed(42),
output_type=input_config.output_type,
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier-free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
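# Classifier-free guidance combines the unconditional and conditional
# predictions as
#   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# (see the last-stage branch of _backbone_forward below).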
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@xFuserPipelineBaseWrapper.enable_data_parallel
@xFuserPipelineBaseWrapper.check_to_use_naive_forward
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
be used instead.
prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` will
be used instead.
height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 28):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
`prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used instead
negative_prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
`text_encoder_3`. If not defined, `negative_prompt` is used instead
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
prompt_3,
height,
width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
#! ---------------------------------------- ADDED BELOW ----------------------------------------
# * set runtime state input parameters
get_runtime_state().set_input_parameters(
height=height,
width=width,
batch_size=batch_size,
num_inference_steps=num_inference_steps,
split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
)
#! ---------------------------------------- ADDED ABOVE ----------------------------------------
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_3=prompt_3,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt,
)
if self.do_classifier_free_guidance:
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
(
prompt_embeds,
pooled_prompt_embeds,
) = self._process_cfg_split_batch(
negative_prompt_embeds,
prompt_embeds,
negative_pooled_prompt_embeds,
pooled_prompt_embeds,
)
#! ORIGIN
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
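# Roughly: with a CFG parallel degree of 2, each rank keeps only its own half
# (negative or positive embeddings) instead of concatenating both along the
# batch dimension; with degree 1 the behavior matches the original torch.cat.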
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps
)
num_warmup_steps = max(
len(timesteps) - num_inference_steps * self.scheduler.order, 0
)
self._num_timesteps = len(timesteps)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Denoising loop
num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
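# PipeFusion schedule: the first `warmup_steps` diffusion steps run
# synchronously across pipeline stages on the full feature map; the remaining
# steps run asynchronously on spatial patches so that stages overlap.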
with self.progress_bar(total=num_inference_steps) as progress_bar:
if (
get_pipeline_parallel_world_size() > 1
and len(timesteps) > num_pipeline_warmup_steps
):
# * warmup stage
latents = self._sync_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
timesteps=timesteps[:num_pipeline_warmup_steps],
num_warmup_steps=num_warmup_steps,
progress_bar=progress_bar,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
# * pipefusion stage
latents = self._async_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
timesteps=timesteps[num_pipeline_warmup_steps:],
num_warmup_steps=num_warmup_steps,
progress_bar=progress_bar,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
else:
latents = self._sync_pipeline(
latents=latents,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
timesteps=timesteps,
num_warmup_steps=num_warmup_steps,
progress_bar=progress_bar,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
sync_only=True,
)
# * 8. Decode latents (only the last rank in a dp group)
def vae_decode(latents):
latents = (
latents / self.vae.config.scaling_factor
) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
return image
if not output_type == "latent":
if get_runtime_state().runtime_config.use_parallel_vae:
latents = self.gather_broadcast_latents(latents)
image = vae_decode(latents)
else:
if is_dp_last_group():
image = vae_decode(latents)
if self.is_dp_last_group():
if output_type == "latent":
image = latents
else:
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return StableDiffusion3PipelineOutput(images=image)
else:
return None
def _init_sync_pipeline(self, latents: torch.Tensor, prompt_embeds: torch.Tensor):
get_runtime_state().set_patched_mode(patch_mode=False)
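# Keep only the height rows this rank owns (its pipeline patches in global
# coordinates) and concatenate them along the height dimension.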
latents_list = [
latents[:, :, start_idx:end_idx, :]
for start_idx, end_idx in get_runtime_state().pp_patches_start_end_idx_global
]
latents = torch.cat(latents_list, dim=-2)
if get_runtime_state().split_text_embed_in_sp:
if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0:
prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
else:
get_runtime_state().split_text_embed_in_sp = False
return latents, prompt_embeds
# compute the whole feature map synchronously in each pipeline-parallel stage
def _sync_pipeline(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
pooled_prompt_embeds: torch.Tensor,
timesteps: List[int],
num_warmup_steps: int,
progress_bar,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
sync_only: bool = False,
):
latents, prompt_embeds = self._init_sync_pipeline(latents, prompt_embeds)
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if is_pipeline_last_stage():
last_timestep_latents = latents
# when there is only one pp stage, no need to recv
if get_pipeline_parallel_world_size() == 1:
pass
# every rank should recv latents from the previous rank, except the first
# rank on the first pipeline forward pass, which uses the input latents
elif is_pipeline_first_stage() and i == 0:
pass
else:
latents = get_pp_group().pipeline_recv()
if not is_pipeline_first_stage():
encoder_hidden_states = get_pp_group().pipeline_recv(
0, "encoder_hidden_states"
)
latents, encoder_hidden_states = self._backbone_forward(
latents=latents,
encoder_hidden_states=(
prompt_embeds
if is_pipeline_first_stage()
else encoder_hidden_states
),
pooled_prompt_embeds=pooled_prompt_embeds,
t=t,
)
if is_pipeline_last_stage():
latents_dtype = latents.dtype
latents = self._scheduler_step(latents, last_timestep_latents, t)
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop(
"negative_prompt_embeds", negative_prompt_embeds
)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if sync_only and is_pipeline_last_stage() and i == len(timesteps) - 1:
pass
elif get_pipeline_parallel_world_size() > 1:
get_pp_group().pipeline_send(latents)
if not is_pipeline_last_stage():
get_pp_group().pipeline_send(
encoder_hidden_states, name="encoder_hidden_states"
)
if (
sync_only
and get_sequence_parallel_world_size() > 1
and is_pipeline_last_stage()
):
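# Sequence parallelism shards the latents along the height dimension; gather
# every SP rank's shard and re-interleave them patch by patch to rebuild the
# full latents on the last pipeline stage.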
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
:,
:,
get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
latents = torch.cat(latents_list, dim=-2)
return latents
def _init_async_pipeline(
self,
num_timesteps: int,
latents: torch.Tensor,
num_pipeline_warmup_steps: int,
):
get_runtime_state().set_patched_mode(patch_mode=True)
if is_pipeline_first_stage():
# get latents computed in warmup stage
# ignore latents after the last timestep
latents = (
get_pp_group().pipeline_recv()
if num_pipeline_warmup_steps > 0
else latents
)
patch_latents = list(
latents.split(get_runtime_state().pp_patches_height, dim=2)
)
elif is_pipeline_last_stage():
patch_latents = list(
latents.split(get_runtime_state().pp_patches_height, dim=2)
)
else:
patch_latents = [
None for _ in range(get_runtime_state().num_pipeline_patch)
]
recv_timesteps = (
num_timesteps - 1 if is_pipeline_first_stage() else num_timesteps
)
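# Pre-register one async recv per (timestep, patch); the first stage skips one
# timestep because its first iteration starts from the warmup/input latents,
# and non-first stages additionally receive encoder_hidden_states per timestep.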
if is_pipeline_first_stage():
for _ in range(recv_timesteps):
for patch_idx in range(get_runtime_state().num_pipeline_patch):
get_pp_group().add_pipeline_recv_task(patch_idx)
else:
for _ in range(recv_timesteps):
get_pp_group().add_pipeline_recv_task(0, "encoder_hidden_states")
for patch_idx in range(get_runtime_state().num_pipeline_patch):
get_pp_group().add_pipeline_recv_task(patch_idx)
return patch_latents
# * implementation of PipeFusion
def _async_pipeline(
self,
latents: torch.Tensor,
prompt_embeds: torch.Tensor,
pooled_prompt_embeds: torch.Tensor,
timesteps: List[int],
num_warmup_steps: int,
progress_bar,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
if len(timesteps) == 0:
return latents
num_pipeline_patch = get_runtime_state().num_pipeline_patch
num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
patch_latents = self._init_async_pipeline(
num_timesteps=len(timesteps),
latents=latents,
num_pipeline_warmup_steps=num_pipeline_warmup_steps,
)
last_patch_latents = (
[None for _ in range(num_pipeline_patch)]
if (is_pipeline_last_stage())
else None
)
first_async_recv = True
for i, t in enumerate(timesteps):
if self.interrupt:
continue
for patch_idx in range(num_pipeline_patch):
if is_pipeline_last_stage():
last_patch_latents[patch_idx] = patch_latents[patch_idx]
if is_pipeline_first_stage() and i == 0:
pass
else:
if first_async_recv:
if not is_pipeline_first_stage() and patch_idx == 0:
get_pp_group().recv_next()
get_pp_group().recv_next()
first_async_recv = False
if not is_pipeline_first_stage() and patch_idx == 0:
last_encoder_hidden_states = (
get_pp_group().get_pipeline_recv_data(
idx=patch_idx, name="encoder_hidden_states"
)
)
patch_latents[patch_idx] = get_pp_group().get_pipeline_recv_data(
idx=patch_idx
)
patch_latents[patch_idx], next_encoder_hidden_states = (
self._backbone_forward(
latents=patch_latents[patch_idx],
encoder_hidden_states=(
prompt_embeds
if is_pipeline_first_stage()
else last_encoder_hidden_states
),
pooled_prompt_embeds=pooled_prompt_embeds,
t=t,
)
)
if is_pipeline_last_stage():
latents_dtype = patch_latents[patch_idx].dtype
patch_latents[patch_idx] = self._scheduler_step(
patch_latents[patch_idx],
last_patch_latents[patch_idx],
t,
)
if patch_latents[patch_idx].dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
patch_latents[patch_idx] = patch_latents[patch_idx].to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(
self, i, t, callback_kwargs
)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop(
"prompt_embeds", prompt_embeds
)
negative_prompt_embeds = callback_outputs.pop(
"negative_prompt_embeds", negative_prompt_embeds
)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds",
negative_pooled_prompt_embeds,
)
if i != len(timesteps) - 1:
get_pp_group().pipeline_isend(
patch_latents[patch_idx], segment_idx=patch_idx
)
else:
if patch_idx == 0:
get_pp_group().pipeline_isend(
next_encoder_hidden_states, name="encoder_hidden_states"
)
get_pp_group().pipeline_isend(
patch_latents[patch_idx], segment_idx=patch_idx
)
if is_pipeline_first_stage() and i == 0:
pass
else:
if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
pass
elif is_pipeline_first_stage():
get_pp_group().recv_next()
else:
# recv encoder_hidden_state
if patch_idx == num_pipeline_patch - 1:
get_pp_group().recv_next()
# recv latents
get_pp_group().recv_next()
get_runtime_state().next_patch()
if i == len(timesteps) - 1 or (
(i + num_pipeline_warmup_steps + 1) > num_warmup_steps
and (i + num_pipeline_warmup_steps + 1) % self.scheduler.order == 0
):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
latents = None
if is_pipeline_last_stage():
latents = torch.cat(patch_latents, dim=2)
if get_sequence_parallel_world_size() > 1:
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(
latents, separate_tensors=True
)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
...,
get_runtime_state()
.pp_patches_start_idx_local[
pp_patch_idx
] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
latents = torch.cat(latents_list, dim=-2)
return latents
def _backbone_forward(
self,
latents: torch.Tensor,
encoder_hidden_states: torch.Tensor,
pooled_prompt_embeds: torch.Tensor,
t: Union[float, torch.Tensor],
):
if is_pipeline_first_stage():
latents = torch.cat(
[latents] * (2 // get_classifier_free_guidance_world_size())
)
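# Duplicate the latents for the unconditional/conditional passes; when CFG is
# split across two ranks each rank processes a single copy.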
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
noise_pred, encoder_hidden_states = self.transformer(
hidden_states=latents,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
# classifier free guidance
if is_pipeline_last_stage():
if get_classifier_free_guidance_world_size() == 1:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
elif get_classifier_free_guidance_world_size() == 2:
noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(
noise_pred, separate_tensors=True
)
latents = noise_pred_uncond + self.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
else:
latents = noise_pred
return latents, encoder_hidden_states
def _scheduler_step(
self,
noise_pred: torch.Tensor,
latents: torch.Tensor,
t: Union[float, torch.Tensor],
):
return self.scheduler.step(
noise_pred,
t,
latents,
return_dict=False,
)[0]
from typing import Dict, Type, Union
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from xfuser.logger import init_logger
from .base_pipeline import xFuserPipelineBaseWrapper
logger = init_logger(__name__)
class xFuserPipelineWrapperRegister:
_XFUSER_PIPE_MAPPING: Dict[
Type[DiffusionPipeline],
Type[xFuserPipelineBaseWrapper]
] = {}
@classmethod
def register(cls, origin_pipe_class: Type[DiffusionPipeline]):
def decorator(xfuser_pipe_class: Type[xFuserPipelineBaseWrapper]):
if not issubclass(xfuser_pipe_class, xFuserPipelineBaseWrapper):
raise ValueError(f"{xfuser_pipe_class} is not a subclass of"
f" xFuserPipelineBaseWrapper")
cls._XFUSER_PIPE_MAPPING[origin_pipe_class] = \
xfuser_pipe_class
return xfuser_pipe_class
return decorator
@classmethod
def get_class(
cls,
pipe: Union[DiffusionPipeline, Type[DiffusionPipeline]]
) -> Type[xFuserPipelineBaseWrapper]:
if isinstance(pipe, type):
candidate = None
candidate_origin = None
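# Walk the registry and keep the most specific matching wrapper: a registered
# base class that is a subclass of the current candidate replaces it.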
for (origin_model_class,
xfuser_model_class) in cls._XFUSER_PIPE_MAPPING.items():
if issubclass(pipe, origin_model_class):
if ((candidate is None and candidate_origin is None) or
issubclass(origin_model_class, candidate_origin)):
candidate_origin = origin_model_class
candidate = xfuser_model_class
if candidate is None:
raise ValueError(f"Diffusion Pipeline class {pipe} "
f"is not supported by xFuser")
else:
return candidate
elif isinstance(pipe, DiffusionPipeline):
candidate = None
candidate_origin = None
for (origin_model_class,
xfuser_model_class) in cls._XFUSER_PIPE_MAPPING.items():
if isinstance(pipe, origin_model_class):
if ((candidate is None and candidate_origin is None) or
issubclass(origin_model_class, candidate_origin)):
candidate_origin = origin_model_class
candidate = xfuser_model_class
if candidate is None:
raise ValueError(f"Diffusion Pipeline class {pipe.__class__} "
f"is not supported by xFuser")
else:
return candidate
else:
raise ValueError(f"Unsupported type {type(pipe)} for pipe")
from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper
from .scheduling_dpmsolver_multistep import (
xFuserDPMSolverMultistepSchedulerWrapper
)
from .scheduling_flow_match_euler_discrete import (
xFuserFlowMatchEulerDiscreteSchedulerWrapper,
)
from .scheduling_ddim import xFuserDDIMSchedulerWrapper
from .scheduling_ddpm import xFuserDDPMSchedulerWrapper
from .scheduling_ddim_cogvideox import xFuserCogVideoXDDIMSchedulerWrapper
from .scheduling_dpm_cogvideox import xFuserCogVideoXDPMSchedulerWrapper
__all__ = [
"xFuserSchedulerWrappersRegister",
"xFuserSchedulerBaseWrapper",
"xFuserDPMSolverMultistepSchedulerWrapper",
"xFuserFlowMatchEulerDiscreteSchedulerWrapper",
"xFuserDDIMSchedulerWrapper",
"xFuserCogVideoXDDIMSchedulerWrapper",
"xFuserCogVideoXDPMSchedulerWrapper",
"xFuserDDPMSchedulerWrapper",
]
from abc import abstractmethod, ABCMeta
from functools import wraps
from typing import List
from diffusers.schedulers import SchedulerMixin
from xfuser.core.distributed import (
get_pipeline_parallel_world_size,
get_sequence_parallel_world_size,
)
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper
class xFuserSchedulerBaseWrapper(xFuserBaseWrapper, metaclass=ABCMeta):
def __init__(
self,
module: SchedulerMixin,
):
super().__init__(
module=module,
)
def __setattr__(self, name, value):
if name == "module":
super().__setattr__(name, value)
elif (
hasattr(self, "module")
and self.module is not None
and hasattr(self.module, name)
):
setattr(self.module, name, value)
else:
super().__setattr__(name, value)
@abstractmethod
def step(self, *args, **kwargs):
pass
@staticmethod
def check_to_use_naive_step(func):
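# Fall back to the wrapped diffusers scheduler's own step() when neither
# pipeline nor sequence parallelism is active, so the wrapper adds no overhead
# in the single-GPU case.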
@wraps(func)
def check_naive_step_fn(self, *args, **kwargs):
if (
get_pipeline_parallel_world_size() == 1
and get_sequence_parallel_world_size() == 1
):
return self.module.step(*args, **kwargs)
else:
return func(self, *args, **kwargs)
return check_naive_step_fn
from typing import Dict, Type
import torch
import torch.nn as nn
from xfuser.logger import init_logger
from xfuser.model_executor.schedulers.base_scheduler import xFuserSchedulerBaseWrapper
logger = init_logger(__name__)
class xFuserSchedulerWrappersRegister:
_XFUSER_SCHEDULER_MAPPING: Dict[
Type[nn.Module],
Type[xFuserSchedulerBaseWrapper]
] = {}
@classmethod
def register(cls, origin_scheduler_class: Type[nn.Module]):
def decorator(xfuser_scheduler_class: Type[nn.Module]):
if not issubclass(xfuser_scheduler_class,
xFuserSchedulerBaseWrapper):
raise ValueError(
f"{xfuser_scheduler_class.__name__} is not "
f"a subclass of xFuserSchedulerBaseWrapper"
)
cls._XFUSER_SCHEDULER_MAPPING[origin_scheduler_class] = \
xfuser_scheduler_class
return xfuser_scheduler_class
return decorator
@classmethod
def get_wrapper(
cls,
scheduler: nn.Module
) -> xFuserSchedulerBaseWrapper:
candidate = None
candidate_origin = None
for (origin_scheduler_class,
wrapper_class) in cls._XFUSER_SCHEDULER_MAPPING.items():
if isinstance(scheduler, origin_scheduler_class):
if ((candidate is None and candidate_origin is None) or
origin_scheduler_class == scheduler.__class__ or
issubclass(origin_scheduler_class, candidate_origin)):
candidate_origin = origin_scheduler_class
candidate = wrapper_class
if candidate is None:
raise ValueError(f"Scheduler class {scheduler.__class__.__name__} "
f"is not supported by xFuser")
else:
return candidate
from typing import Optional, Tuple, Union
import torch
import torch.distributed
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_ddim import (
DDIMScheduler,
DDIMSchedulerOutput,
)
from xfuser.core.distributed import (
get_pipeline_parallel_world_size,
get_sequence_parallel_world_size,
get_runtime_state,
)
from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper
@xFuserSchedulerWrappersRegister.register(DDIMScheduler)
class xFuserDDIMSchedulerWrapper(xFuserSchedulerBaseWrapper):
@xFuserSchedulerBaseWrapper.check_to_use_naive_step
def step(
self,
*args,
**kwargs,
) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
return self.module.step(*args, **kwargs)
from typing import Optional, Tuple, Union
import torch
import torch.distributed
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_ddim_cogvideox import (
CogVideoXDDIMScheduler,
DDIMSchedulerOutput,
)
from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper
@xFuserSchedulerWrappersRegister.register(CogVideoXDDIMScheduler)
class xFuserCogVideoXDDIMSchedulerWrapper(xFuserSchedulerBaseWrapper):
@xFuserSchedulerBaseWrapper.check_to_use_naive_step
def step(
self,
*args,
**kwargs,
) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
return self.module.step(*args, **kwargs)
from typing import Optional, Tuple, Union
import torch
import torch.distributed
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_ddpm import (
DDPMScheduler,
DDPMSchedulerOutput,
)
from xfuser.core.distributed import (
get_pipeline_parallel_world_size,
get_sequence_parallel_world_size,
get_runtime_state,
)
from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper
@xFuserSchedulerWrappersRegister.register(DDPMScheduler)
class xFuserDDPMSchedulerWrapper(xFuserSchedulerBaseWrapper):
@xFuserSchedulerBaseWrapper.check_to_use_naive_step
def step(
self,
*args,
generator=None,
**kwargs,
) -> Union[DDPMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
return self.module.step(*args, generator=generator, **kwargs)
from typing import Optional, Tuple, Union
import torch
import torch.distributed
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_dpm_cogvideox import (
CogVideoXDPMScheduler,
DDIMSchedulerOutput,
)
from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper
@xFuserSchedulerWrappersRegister.register(CogVideoXDPMScheduler)
class xFuserCogVideoXDPMSchedulerWrapper(xFuserSchedulerBaseWrapper):
@xFuserSchedulerBaseWrapper.check_to_use_naive_step
def step(
self,
*args,
**kwargs,
) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
return self.module.step(*args, **kwargs)
from typing import Optional, Tuple, Union
import torch
import torch.distributed
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_dpmsolver_multistep import (
DPMSolverMultistepScheduler,
SchedulerOutput,
)
from xfuser.core.distributed import get_runtime_state
from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper
@xFuserSchedulerWrappersRegister.register(DPMSolverMultistepScheduler)
class xFuserDPMSolverMultistepSchedulerWrapper(xFuserSchedulerBaseWrapper):
@xFuserSchedulerBaseWrapper.check_to_use_naive_step
def step(
self,
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the multistep DPMSolver.
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`LEdits++`].
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
# Improve numerical stability for small number of steps
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
self.config.euler_at_final
or (self.config.lower_order_final and len(self.timesteps) < 15)
or self.config.final_sigmas_type == "zero"
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2)
and self.config.lower_order_final
and len(self.timesteps) < 15
)
model_output = self.convert_model_output(model_output, sample=sample)
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
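# In patch mode each step() call only sees one spatial patch of the model
# output, but the multistep solver needs full-height history tensors: keep a
# full-height buffer per history slot, shift the history once per diffusion
# step (on the first patch), and write each incoming patch into its slice.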
if (
get_runtime_state().patch_mode
and get_runtime_state().pipeline_patch_idx == 0
and self.model_outputs[-1] is None
):
self.model_outputs[-1] = torch.zeros(
[
model_output.shape[0],
model_output.shape[1],
get_runtime_state().pp_patches_start_idx_local[-1],
model_output.shape[3],
],
device=model_output.device,
dtype=model_output.dtype,
)
if get_runtime_state().pipeline_patch_idx == 0:
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
if (
get_runtime_state().patch_mode
and get_runtime_state().pipeline_patch_idx == 0
):
assert len(self.model_outputs) >= 2
self.model_outputs[-1] = torch.zeros_like(self.model_outputs[-2])
if get_runtime_state().patch_mode:
self.model_outputs[-1][
:,
:,
get_runtime_state()
.pp_patches_start_idx_local[
get_runtime_state().pipeline_patch_idx
] : get_runtime_state()
.pp_patches_start_idx_local[get_runtime_state().pipeline_patch_idx + 1],
:,
] = model_output
else:
self.model_outputs[-1] = model_output
#! ORIGIN:
# for i in range(self.config.solver_order - 1):
# self.model_outputs[i] = self.model_outputs[i + 1]
# self.model_outputs[-1] = model_output
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
if (
self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]
and variance_noise is None
):
noise = randn_tensor(
model_output.shape,
generator=generator,
device=model_output.device,
dtype=torch.float32,
)
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
else:
noise = None
#! ---------------------------------------- ADD BELOW ----------------------------------------
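# Slice the buffered history down to the current patch so the solver update
# operates on tensors whose height matches this patch's model_output.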
if get_runtime_state().patch_mode:
model_outputs = []
for output in self.model_outputs:
model_outputs.append(
output[
:,
:,
get_runtime_state()
.pp_patches_start_idx_local[
get_runtime_state().pipeline_patch_idx
] : get_runtime_state()
.pp_patches_start_idx_local[
get_runtime_state().pipeline_patch_idx + 1
],
:,
]
)
else:
model_outputs = self.model_outputs
#! ---------------------------------------- ADD ABOVE ----------------------------------------
if (
self.config.solver_order == 1
or self.lower_order_nums < 1
or lower_order_final
):
prev_sample = self.dpm_solver_first_order_update(
model_output, sample=sample, noise=noise
)
elif (
self.config.solver_order == 2
or self.lower_order_nums < 2
or lower_order_second
):
prev_sample = self.multistep_dpm_solver_second_order_update(
model_outputs, sample=sample, noise=noise
)
else:
prev_sample = self.multistep_dpm_solver_third_order_update(
model_outputs, sample=sample
)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# Cast sample back to expected dtype
prev_sample = prev_sample.to(model_output.dtype)
# upon completion increase step index by one
# * increase step index only when the last pipeline patch is done (or not in patch mode)
if (
not get_runtime_state().patch_mode
or get_runtime_state().pipeline_patch_idx
== get_runtime_state().num_pipeline_patch - 1
):
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
from typing import Optional, Tuple, Union
import torch
import torch.distributed
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.schedulers.scheduling_flow_match_euler_discrete import (
FlowMatchEulerDiscreteSchedulerOutput,
)
from xfuser.core.distributed import get_runtime_state
from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper
@xFuserSchedulerWrappersRegister.register(FlowMatchEulerDiscreteScheduler)
class xFuserFlowMatchEulerDiscreteSchedulerWrapper(xFuserSchedulerBaseWrapper):
@xFuserSchedulerBaseWrapper.check_to_use_naive_step
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
s_tmax (`float`):
s_noise (`float`, defaults to 1.0):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
tuple.
Returns:
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
sigma = self.sigmas[self.step_index]
gamma = (
min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
if s_tmin <= sigma <= s_tmax
else 0.0
)
noise = randn_tensor(
model_output.shape,
dtype=model_output.dtype,
device=model_output.device,
generator=generator,
)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
# backwards compatibility
# if self.config.prediction_type == "vector_field":
denoised = sample - model_output * sigma
# 2. Convert to an ODE derivative
derivative = (sample - denoised) / sigma_hat
dt = self.sigmas[self.step_index + 1] - sigma_hat
prev_sample = sample + derivative * dt
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
# upon completion increase step index by one
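# * in patch mode, only advance the step index after the last pipeline patch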
if (
not get_runtime_state().patch_mode
or get_runtime_state().pipeline_patch_idx
== get_runtime_state().num_pipeline_patch - 1
):
self._step_index += 1
if not return_dict:
return (prev_sample,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
import os
from pathlib import Path
from xfuser.config.config import InputConfig
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
from xfuser.config import EngineConfig
from xfuser.core.distributed.parallel_state import (
get_data_parallel_rank,
get_data_parallel_world_size,
is_dp_last_group,
)
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.logger import init_logger
from xfuser.model_executor.pipelines.base_pipeline import xFuserPipelineBaseWrapper
from xfuser.model_executor.pipelines.register import xFuserPipelineWrapperRegister
logger = init_logger(__name__)
class xDiTParallel:
def __init__(self, pipe, engine_config: EngineConfig, input_config: InputConfig):
xfuser_pipe_wrapper = xFuserPipelineWrapperRegister.get_class(pipe)
self.pipe = xfuser_pipe_wrapper(pipeline=pipe, engine_config=engine_config)
self.config = engine_config
self.pipe.prepare_run(input_config)
def __call__(
self,
*args,
**kwargs,
):
self.result = self.pipe(*args, **kwargs)
return self.result
def save(self, directory: str, prefix: str):
dp_rank = get_data_parallel_rank()
parallel_info = (
f"dp{self.config.parallel_config.dp_degree}_cfg{self.config.parallel_config.cfg_degree}_"
f"ulysses{self.config.parallel_config.ulysses_degree}_ring{self.config.parallel_config.ring_degree}_"
f"pp{self.config.parallel_config.pp_degree}_patch{self.config.parallel_config.pp_config.num_pipeline_patch}"
)
if is_dp_last_group():
path = Path(f"{directory}")
path.mkdir(mode=0o755, parents=True, exist_ok=True)
path = path / f"{prefix}_result_{parallel_info}_dprank{dp_rank}"
for i, image in enumerate(self.result.images):
image.save(f"{str(path)}_image{i}.png")
print(f"{str(path)}_image{i}.png")
def __del__(self):
get_runtime_state().destory_distributed_env()
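# A hedged usage sketch: argument names follow this file, but EngineConfig and
# InputConfig construction depends on the xfuser version in use.
#   pipe = StableDiffusion3Pipeline.from_pretrained("...")
#   parallel_pipe = xDiTParallel(pipe, engine_config, input_config)
#   parallel_pipe(prompt="a photo of a cat", num_inference_steps=28)
#   parallel_pipe.save("results", "sd3")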