Bw-bestperf / HunyuanVideo-T2V · Commits · 0513d03d

Commit 0513d03d authored Feb 03, 2026 by jerrrrry — "Initial commit"
Pipeline #3321 canceled with stages · Changes 152 · Pipelines 1
Showing 12 changed files with 1599 additions and 0 deletions (+1599 −0)

xfuser1/model_executor/pipelines/pipeline_stable_diffusion_3.py            +799 −0
xfuser1/model_executor/pipelines/register.py                                 +64 −0
xfuser1/model_executor/schedulers/__init__.py                                +24 −0
xfuser1/model_executor/schedulers/base_scheduler.py                          +50 −0
xfuser1/model_executor/schedulers/register.py                                +51 −0
xfuser1/model_executor/schedulers/scheduling_ddim.py                         +62 −0
xfuser1/model_executor/schedulers/scheduling_ddim_cogvideox.py               +57 −0
xfuser1/model_executor/schedulers/scheduling_ddpm.py                         +53 −0
xfuser1/model_executor/schedulers/scheduling_dpm_cogvideox.py                +57 −0
xfuser1/model_executor/schedulers/scheduling_dpmsolver_multistep.py         +201 −0
xfuser1/model_executor/schedulers/scheduling_flow_match_euler_discrete.py   +127 −0
xfuser1/parallel.py                                                          +54 −0
xfuser1/model_executor/pipelines/pipeline_stable_diffusion_3.py  (new file, mode 100755)
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, List, Tuple, Callable, Optional, Union

import torch
import torch.distributed
from diffusers import StableDiffusion3Pipeline
from diffusers.utils import is_torch_xla_available
from diffusers.pipelines.stable_diffusion_3.pipeline_output import (
    StableDiffusion3PipelineOutput,
)
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
    retrieve_timesteps,
)

from xfuser.config import EngineConfig, InputConfig
from xfuser.core.distributed import (
    get_pipeline_parallel_world_size,
    is_pipeline_first_stage,
    is_pipeline_last_stage,
    get_runtime_state,
    get_cfg_group,
    get_classifier_free_guidance_world_size,
    get_pp_group,
    get_sequence_parallel_world_size,
    get_sequence_parallel_rank,
    get_sp_group,
    is_dp_last_group,
)
from .base_pipeline import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


@xFuserPipelineWrapperRegister.register(StableDiffusion3Pipeline)
class xFuserStableDiffusion3Pipeline(xFuserPipelineBaseWrapper):

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        engine_config: EngineConfig,
        **kwargs,
    ):
        pipeline = StableDiffusion3Pipeline.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        return cls(pipeline, engine_config)

    def prepare_run(
        self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1
    ):
        prompt = [""] * input_config.batch_size if input_config.batch_size > 1 else ""
        warmup_steps = get_runtime_state().runtime_config.warmup_steps
        get_runtime_state().runtime_config.warmup_steps = sync_steps
        self.__call__(
            height=input_config.height,
            width=input_config.width,
            prompt=prompt,
            num_inference_steps=steps,
            generator=torch.Generator(device="cuda").manual_seed(42),
            output_type=input_config.output_type,
        )
        get_runtime_state().runtime_config.warmup_steps = warmup_steps

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @xFuserPipelineBaseWrapper.enable_data_parallel
    @xFuserPipelineBaseWrapper.check_to_use_naive_forward
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 7.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
will be used instead.
prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
will be used instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used instead
negative_prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
`text_encoder_3`. If not defined, `negative_prompt` is used instead
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        #! ---------------------------------------- ADDED BELOW ----------------------------------------
        # * set runtime state input parameters
        get_runtime_state().set_input_parameters(
            height=height,
            width=width,
            batch_size=batch_size,
            num_inference_steps=num_inference_steps,
            split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
        )
        #! ---------------------------------------- ADDED ABOVE ----------------------------------------

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
        )

        if self.do_classifier_free_guidance:
            #! ---------------------------------------- MODIFIED BELOW ----------------------------------------
            (
                prompt_embeds,
                pooled_prompt_embeds,
            ) = self._process_cfg_split_batch(
                negative_prompt_embeds,
                prompt_embeds,
                negative_pooled_prompt_embeds,
                pooled_prompt_embeds,
            )

            #! ORIGIN
            # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            # pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
            #! ---------------------------------------- MODIFIED ABOVE ----------------------------------------

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps
        )
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0
        )
        self._num_timesteps = len(timesteps)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Denoising loop
        num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            if (
                get_pipeline_parallel_world_size() > 1
                and len(timesteps) > num_pipeline_warmup_steps
            ):
                # * warmup stage
                latents = self._sync_pipeline(
                    latents=latents,
                    prompt_embeds=prompt_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    timesteps=timesteps[:num_pipeline_warmup_steps],
                    num_warmup_steps=num_warmup_steps,
                    progress_bar=progress_bar,
                    callback_on_step_end=callback_on_step_end,
                    callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
                )
                # * pipefusion stage
                latents = self._async_pipeline(
                    latents=latents,
                    prompt_embeds=prompt_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    timesteps=timesteps[num_pipeline_warmup_steps:],
                    num_warmup_steps=num_warmup_steps,
                    progress_bar=progress_bar,
                    callback_on_step_end=callback_on_step_end,
                    callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
                )
            else:
                latents = self._sync_pipeline(
                    latents=latents,
                    prompt_embeds=prompt_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    timesteps=timesteps,
                    num_warmup_steps=num_warmup_steps,
                    progress_bar=progress_bar,
                    callback_on_step_end=callback_on_step_end,
                    callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
                    sync_only=True,
                )

        # * 8. Decode latents (only the last rank in a dp group)
        def vae_decode(latents):
            latents = (
                latents / self.vae.config.scaling_factor
            ) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            return image

        if not output_type == "latent":
            if get_runtime_state().runtime_config.use_parallel_vae:
                latents = self.gather_broadcast_latents(latents)
                image = vae_decode(latents)
            else:
                if is_dp_last_group():
                    image = vae_decode(latents)

        if self.is_dp_last_group():
            if output_type == "latent":
                image = latents
            else:
                image = self.image_processor.postprocess(image, output_type=output_type)

            # Offload all models
            self.maybe_free_model_hooks()

            if not return_dict:
                return (image,)

            return StableDiffusion3PipelineOutput(images=image)
        else:
            return None

    def _init_sync_pipeline(self, latents: torch.Tensor, prompt_embeds: torch.Tensor):
        get_runtime_state().set_patched_mode(patch_mode=False)

        latents_list = [
            latents[:, :, start_idx:end_idx, :]
            for start_idx, end_idx in get_runtime_state().pp_patches_start_end_idx_global
        ]
        latents = torch.cat(latents_list, dim=-2)

        if get_runtime_state().split_text_embed_in_sp:
            if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0:
                prompt_embeds = torch.chunk(
                    prompt_embeds, get_sequence_parallel_world_size(), dim=-2
                )[get_sequence_parallel_rank()]
            else:
                get_runtime_state().split_text_embed_in_sp = False

        return latents, prompt_embeds

    # synchronized compute the whole feature map in each pp stage
    def _sync_pipeline(
        self,
        latents: torch.Tensor,
        prompt_embeds: torch.Tensor,
        pooled_prompt_embeds: torch.Tensor,
        timesteps: List[int],
        num_warmup_steps: int,
        progress_bar,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        sync_only: bool = False,
    ):
        latents, prompt_embeds = self._init_sync_pipeline(latents, prompt_embeds)
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            if is_pipeline_last_stage():
                last_timestep_latents = latents

            # when there is only one pp stage, no need to recv
            if get_pipeline_parallel_world_size() == 1:
                pass
            # all ranks should recv the latent from the previous rank except
            # the first rank in the first pipeline forward which should use
            # the input latent
            elif is_pipeline_first_stage() and i == 0:
                pass
            else:
                latents = get_pp_group().pipeline_recv()
                if not is_pipeline_first_stage():
                    encoder_hidden_states = get_pp_group().pipeline_recv(
                        0, "encoder_hidden_states"
                    )

            latents, encoder_hidden_states = self._backbone_forward(
                latents=latents,
                encoder_hidden_states=(
                    prompt_embeds if is_pipeline_first_stage() else encoder_hidden_states
                ),
                pooled_prompt_embeds=pooled_prompt_embeds,
                t=t,
            )

            if is_pipeline_last_stage():
                latents_dtype = latents.dtype
                latents = self._scheduler_step(latents, last_timestep_latents, t)

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop(
                        "negative_prompt_embeds", negative_prompt_embeds
                    )
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                    )

                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

            if sync_only and is_pipeline_last_stage() and i == len(timesteps) - 1:
                pass
            elif get_pipeline_parallel_world_size() > 1:
                get_pp_group().pipeline_send(latents)
                if not is_pipeline_last_stage():
                    get_pp_group().pipeline_send(
                        encoder_hidden_states, name="encoder_hidden_states"
                    )

        if (
            sync_only
            and get_sequence_parallel_world_size() > 1
            and is_pipeline_last_stage()
        ):
            sp_degree = get_sequence_parallel_world_size()
            sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
            latents_list = []
            for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
                latents_list += [
                    sp_latents_list[sp_patch_idx][
                        :,
                        :,
                        get_runtime_state().pp_patches_start_idx_local[pp_patch_idx] :
                        get_runtime_state().pp_patches_start_idx_local[pp_patch_idx + 1],
                        :,
                    ]
                    for sp_patch_idx in range(sp_degree)
                ]
            latents = torch.cat(latents_list, dim=-2)

        return latents

    def _init_async_pipeline(
        self,
        num_timesteps: int,
        latents: torch.Tensor,
        num_pipeline_warmup_steps: int,
    ):
        get_runtime_state().set_patched_mode(patch_mode=True)

        if is_pipeline_first_stage():
            # get latents computed in warmup stage
            # ignore latents after the last timestep
            latents = (
                get_pp_group().pipeline_recv()
                if num_pipeline_warmup_steps > 0
                else latents
            )
            patch_latents = list(
                latents.split(get_runtime_state().pp_patches_height, dim=2)
            )
        elif is_pipeline_last_stage():
            patch_latents = list(
                latents.split(get_runtime_state().pp_patches_height, dim=2)
            )
        else:
            patch_latents = [None for _ in range(get_runtime_state().num_pipeline_patch)]

        recv_timesteps = (
            num_timesteps - 1 if is_pipeline_first_stage() else num_timesteps
        )

        if is_pipeline_first_stage():
            for _ in range(recv_timesteps):
                for patch_idx in range(get_runtime_state().num_pipeline_patch):
                    get_pp_group().add_pipeline_recv_task(patch_idx)
        else:
            for _ in range(recv_timesteps):
                get_pp_group().add_pipeline_recv_task(0, "encoder_hidden_states")
                for patch_idx in range(get_runtime_state().num_pipeline_patch):
                    get_pp_group().add_pipeline_recv_task(patch_idx)

        return patch_latents

    # * implement of pipefusion
    def _async_pipeline(
        self,
        latents: torch.Tensor,
        prompt_embeds: torch.Tensor,
        pooled_prompt_embeds: torch.Tensor,
        timesteps: List[int],
        num_warmup_steps: int,
        progress_bar,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    ):
        if len(timesteps) == 0:
            return latents
        num_pipeline_patch = get_runtime_state().num_pipeline_patch
        num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
        patch_latents = self._init_async_pipeline(
            num_timesteps=len(timesteps),
            latents=latents,
            num_pipeline_warmup_steps=num_pipeline_warmup_steps,
        )
        last_patch_latents = (
            [None for _ in range(num_pipeline_patch)]
            if (is_pipeline_last_stage())
            else None
        )

        first_async_recv = True
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue
            for patch_idx in range(num_pipeline_patch):
                if is_pipeline_last_stage():
                    last_patch_latents[patch_idx] = patch_latents[patch_idx]

                if is_pipeline_first_stage() and i == 0:
                    pass
                else:
                    if first_async_recv:
                        if not is_pipeline_first_stage() and patch_idx == 0:
                            get_pp_group().recv_next()
                        get_pp_group().recv_next()
                        first_async_recv = False

                    if not is_pipeline_first_stage() and patch_idx == 0:
                        last_encoder_hidden_states = (
                            get_pp_group().get_pipeline_recv_data(
                                idx=patch_idx, name="encoder_hidden_states"
                            )
                        )
                    patch_latents[patch_idx] = get_pp_group().get_pipeline_recv_data(
                        idx=patch_idx
                    )

                patch_latents[patch_idx], next_encoder_hidden_states = (
                    self._backbone_forward(
                        latents=patch_latents[patch_idx],
                        encoder_hidden_states=(
                            prompt_embeds
                            if is_pipeline_first_stage()
                            else last_encoder_hidden_states
                        ),
                        pooled_prompt_embeds=pooled_prompt_embeds,
                        t=t,
                    )
                )
                if is_pipeline_last_stage():
                    latents_dtype = patch_latents[patch_idx].dtype
                    patch_latents[patch_idx] = self._scheduler_step(
                        patch_latents[patch_idx],
                        last_patch_latents[patch_idx],
                        t,
                    )

                    if latents.dtype != latents_dtype:
                        if torch.backends.mps.is_available():
                            # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                            latents = latents.to(latents_dtype)

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                        negative_prompt_embeds = callback_outputs.pop(
                            "negative_prompt_embeds", negative_prompt_embeds
                        )
                        negative_pooled_prompt_embeds = callback_outputs.pop(
                            "negative_pooled_prompt_embeds",
                            negative_pooled_prompt_embeds,
                        )

                    if i != len(timesteps) - 1:
                        get_pp_group().pipeline_isend(
                            patch_latents[patch_idx], segment_idx=patch_idx
                        )
                else:
                    if patch_idx == 0:
                        get_pp_group().pipeline_isend(
                            next_encoder_hidden_states, name="encoder_hidden_states"
                        )
                    get_pp_group().pipeline_isend(
                        patch_latents[patch_idx], segment_idx=patch_idx
                    )

                if is_pipeline_first_stage() and i == 0:
                    pass
                else:
                    if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
                        pass
                    elif is_pipeline_first_stage():
                        get_pp_group().recv_next()
                    else:
                        # recv encoder_hidden_state
                        if patch_idx == num_pipeline_patch - 1:
                            get_pp_group().recv_next()
                        # recv latents
                        get_pp_group().recv_next()

                get_runtime_state().next_patch()

            if i == len(timesteps) - 1 or (
                (i + num_pipeline_warmup_steps + 1) > num_warmup_steps
                and (i + num_pipeline_warmup_steps + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

        latents = None
        if is_pipeline_last_stage():
            latents = torch.cat(patch_latents, dim=2)
            if get_sequence_parallel_world_size() > 1:
                sp_degree = get_sequence_parallel_world_size()
                sp_latents_list = get_sp_group().all_gather(
                    latents, separate_tensors=True
                )
                latents_list = []
                for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
                    latents_list += [
                        sp_latents_list[sp_patch_idx][
                            ...,
                            get_runtime_state().pp_patches_start_idx_local[pp_patch_idx] :
                            get_runtime_state().pp_patches_start_idx_local[pp_patch_idx + 1],
                            :,
                        ]
                        for sp_patch_idx in range(sp_degree)
                    ]
                latents = torch.cat(latents_list, dim=-2)
        return latents

    def _backbone_forward(
        self,
        latents: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        pooled_prompt_embeds: torch.Tensor,
        t: Union[float, torch.Tensor],
    ):
        if is_pipeline_first_stage():
            latents = torch.cat(
                [latents] * (2 // get_classifier_free_guidance_world_size())
            )

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latents.shape[0])

        noise_pred, encoder_hidden_states = self.transformer(
            hidden_states=latents,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_prompt_embeds,
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
        )[0]

        # classifier free guidance
        if is_pipeline_last_stage():
            if get_classifier_free_guidance_world_size() == 1:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            elif get_classifier_free_guidance_world_size() == 2:
                noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(
                    noise_pred, separate_tensors=True
                )
            latents = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        else:
            latents = noise_pred

        return latents, encoder_hidden_states

    def _scheduler_step(
        self,
        noise_pred: torch.Tensor,
        latents: torch.Tensor,
        t: Union[float, torch.Tensor],
    ):
        return self.scheduler.step(
            noise_pred,
            t,
            latents,
            return_dict=False,
        )[0]
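
A minimal usage sketch for the wrapper above (not part of this commit). It assumes an `EngineConfig` and `InputConfig` have already been built elsewhere (for example by the project's CLI/config helpers); the checkpoint name and generation arguments are illustrative only.

    # Hypothetical driver script; engine_config / input_config are assumed to exist.
    import torch

    pipe = xFuserStableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",  # example checkpoint
        engine_config,
        torch_dtype=torch.float16,
    )
    pipe.prepare_run(input_config)  # short synchronized warm-up run
    output = pipe(
        prompt="a photo of an astronaut riding a horse",
        num_inference_steps=28,
        generator=torch.Generator(device="cuda").manual_seed(42),
    )
    if output is not None:  # only the last data-parallel rank returns images
        output.images[0].save("sd3_sample.png")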
xfuser1/model_executor/pipelines/register.py  (new file, mode 100755)
from typing import Dict, Type, Union

from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from xfuser.logger import init_logger
from .base_pipeline import xFuserPipelineBaseWrapper

logger = init_logger(__name__)


class xFuserPipelineWrapperRegister:
    _XFUSER_PIPE_MAPPING: Dict[
        Type[DiffusionPipeline], Type[xFuserPipelineBaseWrapper]
    ] = {}

    @classmethod
    def register(cls, origin_pipe_class: Type[DiffusionPipeline]):
        def decorator(xfuser_pipe_class: Type[xFuserPipelineBaseWrapper]):
            if not issubclass(xfuser_pipe_class, xFuserPipelineBaseWrapper):
                raise ValueError(
                    f"{xfuser_pipe_class} is not a subclass of"
                    f" xFuserPipelineBaseWrapper"
                )
            cls._XFUSER_PIPE_MAPPING[origin_pipe_class] = xfuser_pipe_class
            return xfuser_pipe_class

        return decorator

    @classmethod
    def get_class(
        cls, pipe: Union[DiffusionPipeline, Type[DiffusionPipeline]]
    ) -> Type[xFuserPipelineBaseWrapper]:
        if isinstance(pipe, type):
            candidate = None
            candidate_origin = None
            for origin_model_class, xfuser_model_class in cls._XFUSER_PIPE_MAPPING.items():
                if issubclass(pipe, origin_model_class):
                    if (
                        candidate is None and candidate_origin is None
                    ) or issubclass(origin_model_class, candidate_origin):
                        candidate_origin = origin_model_class
                        candidate = xfuser_model_class

            if candidate is None:
                raise ValueError(
                    f"Diffusion Pipeline class {pipe} "
                    f"is not supported by xFuser"
                )
            else:
                return candidate
        elif isinstance(pipe, DiffusionPipeline):
            candidate = None
            candidate_origin = None
            for origin_model_class, xfuser_model_class in cls._XFUSER_PIPE_MAPPING.items():
                if isinstance(pipe, origin_model_class):
                    if (
                        candidate is None and candidate_origin is None
                    ) or issubclass(origin_model_class, candidate_origin):
                        candidate_origin = origin_model_class
                        candidate = xfuser_model_class

            if candidate is None:
                raise ValueError(
                    f"Diffusion Pipeline class {pipe.__class__} "
                    f"is not supported by xFuser"
                )
            else:
                return candidate
        else:
            raise ValueError(f"Unsupported type {type(pipe)} for pipe")
\ No newline at end of file
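
The registry above resolves a diffusers pipeline (class or instance) to its xFuser wrapper, preferring the most specific registered base class. A hedged sketch of how a new wrapper would be hooked in; `MyPipeline` and `xFuserMyPipeline` are hypothetical names used only for illustration.

    # Registration: decorate the wrapper with the original diffusers pipeline class.
    @xFuserPipelineWrapperRegister.register(MyPipeline)
    class xFuserMyPipeline(xFuserPipelineBaseWrapper):
        ...

    # Lookup works with either the class or an instantiated pipeline.
    wrapper_cls = xFuserPipelineWrapperRegister.get_class(MyPipeline)
    assert wrapper_cls is xFuserMyPipeline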
xfuser1/model_executor/schedulers/__init__.py  (new file, mode 100755)
from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper
from .scheduling_dpmsolver_multistep import (
    xFuserDPMSolverMultistepSchedulerWrapper,
)
from .scheduling_flow_match_euler_discrete import (
    xFuserFlowMatchEulerDiscreteSchedulerWrapper,
)
from .scheduling_ddim import xFuserDDIMSchedulerWrapper
from .scheduling_ddpm import xFuserDDPMSchedulerWrapper
from .scheduling_ddim_cogvideox import xFuserCogVideoXDDIMSchedulerWrapper
from .scheduling_dpm_cogvideox import xFuserCogVideoXDPMSchedulerWrapper

__all__ = [
    "xFuserSchedulerWrappersRegister",
    "xFuserSchedulerBaseWrapper",
    "xFuserDPMSolverMultistepSchedulerWrapper",
    "xFuserFlowMatchEulerDiscreteSchedulerWrapper",
    "xFuserDDIMSchedulerWrapper",
    "xFuserCogVideoXDDIMSchedulerWrapper",
    "xFuserCogVideoXDPMSchedulerWrapper",
    "xFuserDDPMSchedulerWrapper",
]
\ No newline at end of file
xfuser1/model_executor/schedulers/base_scheduler.py  (new file, mode 100755)
from abc import abstractmethod, ABCMeta
from functools import wraps
from typing import List

from diffusers.schedulers import SchedulerMixin

from xfuser.core.distributed import (
    get_pipeline_parallel_world_size,
    get_sequence_parallel_world_size,
)
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper


class xFuserSchedulerBaseWrapper(xFuserBaseWrapper, metaclass=ABCMeta):
    def __init__(
        self,
        module: SchedulerMixin,
    ):
        super().__init__(
            module=module,
        )

    def __setattr__(self, name, value):
        if name == "module":
            super().__setattr__(name, value)
        elif (
            hasattr(self, "module")
            and self.module is not None
            and hasattr(self.module, name)
        ):
            setattr(self.module, name, value)
        else:
            super().__setattr__(name, value)

    @abstractmethod
    def step(self, *args, **kwargs):
        pass

    @staticmethod
    def check_to_use_naive_step(func):
        @wraps(func)
        def check_naive_step_fn(self, *args, **kwargs):
            if (
                get_pipeline_parallel_world_size() == 1
                and get_sequence_parallel_world_size() == 1
            ):
                return self.module.step(*args, **kwargs)
            else:
                return func(self, *args, **kwargs)

        return check_naive_step_fn
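
The `check_to_use_naive_step` decorator above short-circuits to the wrapped diffusers scheduler whenever neither pipeline parallelism nor sequence parallelism is active, so a parallel-aware `step` override only runs when it is actually needed. A minimal sketch of a wrapper relying on it; `SomeDiffusersScheduler` is a hypothetical origin class used only for illustration.

    @xFuserSchedulerWrappersRegister.register(SomeDiffusersScheduler)  # hypothetical origin class
    class xFuserSomeSchedulerWrapper(xFuserSchedulerBaseWrapper):

        @xFuserSchedulerBaseWrapper.check_to_use_naive_step
        def step(self, *args, **kwargs):
            # Reached only when pipeline or sequence parallel world size > 1;
            # otherwise the decorator has already returned self.module.step(*args, **kwargs).
            return self.module.step(*args, **kwargs)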
xfuser1/model_executor/schedulers/register.py  (new file, mode 100755)
from typing import Dict, Type

import torch
import torch.nn as nn

from xfuser.logger import init_logger
from xfuser.model_executor.schedulers.base_scheduler import xFuserSchedulerBaseWrapper

logger = init_logger(__name__)


class xFuserSchedulerWrappersRegister:
    _XFUSER_SCHEDULER_MAPPING: Dict[
        Type[nn.Module], Type[xFuserSchedulerBaseWrapper]
    ] = {}

    @classmethod
    def register(cls, origin_scheduler_class: Type[nn.Module]):
        def decorator(xfuser_scheduler_class: Type[nn.Module]):
            if not issubclass(xfuser_scheduler_class, xFuserSchedulerBaseWrapper):
                raise ValueError(
                    f"{xfuser_scheduler_class.__class__.__name__} is not "
                    f"a subclass of xFuserSchedulerBaseWrapper"
                )
            cls._XFUSER_SCHEDULER_MAPPING[origin_scheduler_class] = xfuser_scheduler_class
            return xfuser_scheduler_class

        return decorator

    @classmethod
    def get_wrapper(cls, scheduler: nn.Module) -> xFuserSchedulerBaseWrapper:
        candidate = None
        candidate_origin = None
        for origin_scheduler_class, wrapper_class in cls._XFUSER_SCHEDULER_MAPPING.items():
            if isinstance(scheduler, origin_scheduler_class):
                if (
                    (candidate is None and candidate_origin is None)
                    or origin_scheduler_class == scheduler.__class__
                    or issubclass(origin_scheduler_class, candidate_origin)
                ):
                    candidate_origin = origin_scheduler_class
                    candidate = wrapper_class

        if candidate is None:
            raise ValueError(
                f"Scheduler class {scheduler.__class__.__name__} "
                f"is not supported by xFuser"
            )
        else:
            return candidate
\ No newline at end of file
xfuser1/model_executor/schedulers/scheduling_ddim.py  (new file, mode 100755)
from typing import Optional, Tuple, Union

import torch
import torch.distributed
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_ddim import (
    DDIMScheduler,
    DDIMSchedulerOutput,
)

from xfuser.core.distributed import (
    get_pipeline_parallel_world_size,
    get_sequence_parallel_world_size,
    get_runtime_state,
)
from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper


@xFuserSchedulerWrappersRegister.register(DDIMScheduler)
class xFuserDDIMSchedulerWrapper(xFuserSchedulerBaseWrapper):

    @xFuserSchedulerBaseWrapper.check_to_use_naive_step
    def step(
        self,
        *args,
        **kwargs,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
        return self.module.step(*args, **kwargs)
xfuser1/model_executor/schedulers/scheduling_ddim_cogvideox.py  (new file, mode 100755)
from typing import Optional, Tuple, Union

import torch
import torch.distributed
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_ddim_cogvideox import (
    CogVideoXDDIMScheduler,
    DDIMSchedulerOutput,
)

from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper


@xFuserSchedulerWrappersRegister.register(CogVideoXDDIMScheduler)
class xFuserCogVideoXDDIMSchedulerWrapper(xFuserSchedulerBaseWrapper):

    @xFuserSchedulerBaseWrapper.check_to_use_naive_step
    def step(
        self,
        *args,
        **kwargs,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
        return self.module.step(*args, **kwargs)
xfuser1/model_executor/schedulers/scheduling_ddpm.py  (new file, mode 100755)
from typing import Optional, Tuple, Union

import torch
import torch.distributed
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_ddpm import (
    DDPMScheduler,
    DDPMSchedulerOutput,
)

from xfuser.core.distributed import (
    get_pipeline_parallel_world_size,
    get_sequence_parallel_world_size,
    get_runtime_state,
)
from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper


@xFuserSchedulerWrappersRegister.register(DDPMScheduler)
class xFuserDDPMSchedulerWrapper(xFuserSchedulerBaseWrapper):

    @xFuserSchedulerBaseWrapper.check_to_use_naive_step
    def step(
        self,
        *args,
        generator=None,
        **kwargs,
    ) -> Union[DDPMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
        return self.module.step(*args, generator, **kwargs)
xfuser1/model_executor/schedulers/scheduling_dpm_cogvideox.py  (new file, mode 100755)
from typing import Optional, Tuple, Union

import torch
import torch.distributed
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_dpm_cogvideox import (
    CogVideoXDPMScheduler,
    DDIMSchedulerOutput,
)

from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper


@xFuserSchedulerWrappersRegister.register(CogVideoXDPMScheduler)
class xFuserCogVideoXDPMSchedulerWrapper(xFuserSchedulerBaseWrapper):

    @xFuserSchedulerBaseWrapper.check_to_use_naive_step
    def step(
        self,
        *args,
        **kwargs,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
        return self.module.step(*args, **kwargs)
xfuser1/model_executor/schedulers/scheduling_dpmsolver_multistep.py  (new file, mode 100755)
from typing import Optional, Tuple, Union

import torch
import torch.distributed
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_dpmsolver_multistep import (
    DPMSolverMultistepScheduler,
    SchedulerOutput,
)

from xfuser.core.distributed import get_runtime_state
from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper


@xFuserSchedulerWrappersRegister.register(DPMSolverMultistepScheduler)
class xFuserDPMSolverMultistepSchedulerWrapper(xFuserSchedulerBaseWrapper):

    @xFuserSchedulerBaseWrapper.check_to_use_naive_step
    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the multistep DPMSolver.
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`LEdits++`].
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Improve numerical stability for small number of steps
        lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
            self.config.euler_at_final
            or (self.config.lower_order_final and len(self.timesteps) < 15)
            or self.config.final_sigmas_type == "zero"
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2)
            and self.config.lower_order_final
            and len(self.timesteps) < 15
        )

        model_output = self.convert_model_output(model_output, sample=sample)

        #! ---------------------------------------- MODIFIED BELOW ----------------------------------------
        if (
            get_runtime_state().patch_mode
            and get_runtime_state().pipeline_patch_idx == 0
            and self.model_outputs[-1] is None
        ):
            self.model_outputs[-1] = torch.zeros(
                [
                    model_output.shape[0],
                    model_output.shape[1],
                    get_runtime_state().pp_patches_start_idx_local[-1],
                    model_output.shape[3],
                ],
                device=model_output.device,
                dtype=model_output.dtype,
            )

        if get_runtime_state().pipeline_patch_idx == 0:
            for i in range(self.config.solver_order - 1):
                self.model_outputs[i] = self.model_outputs[i + 1]

        if (
            get_runtime_state().patch_mode
            and get_runtime_state().pipeline_patch_idx == 0
        ):
            assert len(self.model_outputs) >= 2
            self.model_outputs[-1] = torch.zeros_like(self.model_outputs[-2])

        if get_runtime_state().patch_mode:
            self.model_outputs[-1][
                :,
                :,
                get_runtime_state().pp_patches_start_idx_local[
                    get_runtime_state().pipeline_patch_idx
                ] : get_runtime_state().pp_patches_start_idx_local[
                    get_runtime_state().pipeline_patch_idx + 1
                ],
                :,
            ] = model_output
        else:
            self.model_outputs[-1] = model_output

        #! ORIGIN:
        # for i in range(self.config.solver_order - 1):
        #     self.model_outputs[i] = self.model_outputs[i + 1]
        # self.model_outputs[-1] = model_output
        #! ---------------------------------------- MODIFIED ABOVE ----------------------------------------

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        if (
            self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]
            and variance_noise is None
        ):
            noise = randn_tensor(
                model_output.shape,
                generator=generator,
                device=model_output.device,
                dtype=torch.float32,
            )
        elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
        else:
            noise = None

        #! ---------------------------------------- ADD BELOW ----------------------------------------
        if get_runtime_state().patch_mode:
            model_outputs = []
            for output in self.model_outputs:
                model_outputs.append(
                    output[
                        :,
                        :,
                        get_runtime_state().pp_patches_start_idx_local[
                            get_runtime_state().pipeline_patch_idx
                        ] : get_runtime_state().pp_patches_start_idx_local[
                            get_runtime_state().pipeline_patch_idx + 1
                        ],
                        :,
                    ]
                )
        else:
            model_outputs = self.model_outputs
        #! ---------------------------------------- ADD ABOVE ----------------------------------------

        if (
            self.config.solver_order == 1
            or self.lower_order_nums < 1
            or lower_order_final
        ):
            prev_sample = self.dpm_solver_first_order_update(
                model_output, sample=sample, noise=noise
            )
        elif (
            self.config.solver_order == 2
            or self.lower_order_nums < 2
            or lower_order_second
        ):
            prev_sample = self.multistep_dpm_solver_second_order_update(
                model_outputs, sample=sample, noise=noise
            )
        else:
            prev_sample = self.multistep_dpm_solver_third_order_update(
                model_outputs, sample=sample
            )

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # Cast sample back to expected dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        # * increase step index only when the last pipeline patch is done (or not in patch mode)
        if (
            not get_runtime_state().patch_mode
            or get_runtime_state().pipeline_patch_idx
            == get_runtime_state().num_pipeline_patch - 1
        ):
            self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)
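
The patch-mode bookkeeping above keeps the multistep history (`self.model_outputs`) at full spatial height, writes each pipeline patch into its slice of that cache, and reads the current patch back out before the solver update. A toy illustration of that slicing with plain tensors; the shapes and patch boundaries are made up for the example and are not taken from the commit.

    import torch

    # Full-height cached output and per-patch start indices (e.g. two patches of height 8).
    cached = torch.zeros(2, 16, 16, 64)            # (batch, channels, height, width)
    pp_patches_start_idx_local = [0, 8, 16]
    patch_idx = 1
    patch_out = torch.randn(2, 16, 8, 64)          # output for the current patch only

    start = pp_patches_start_idx_local[patch_idx]
    end = pp_patches_start_idx_local[patch_idx + 1]
    cached[:, :, start:end, :] = patch_out         # write the patch into the full-height cache
    assert torch.equal(cached[:, :, start:end, :], patch_out)  # read it back for the solver update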
xfuser1/model_executor/schedulers/scheduling_flow_match_euler_discrete.py  (new file, mode 100755)
from typing import Optional, Tuple, Union

import torch
import torch.distributed
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteSchedulerOutput,
)

from xfuser.core.distributed import get_runtime_state
from .register import xFuserSchedulerWrappersRegister
from .base_scheduler import xFuserSchedulerBaseWrapper


@xFuserSchedulerWrappersRegister.register(FlowMatchEulerDiscreteScheduler)
class xFuserFlowMatchEulerDiscreteSchedulerWrapper(xFuserSchedulerBaseWrapper):

    @xFuserSchedulerBaseWrapper.check_to_use_naive_step
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
s_tmax (`float`):
s_noise (`float`, defaults to 1.0):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
tuple.
Returns:
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]

        gamma = (
            min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigma <= s_tmax
            else 0.0
        )

        noise = randn_tensor(
            model_output.shape,
            dtype=model_output.dtype,
            device=model_output.device,
            generator=generator,
        )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility
        # if self.config.prediction_type == "vector_field":
        denoised = sample - model_output * sigma
        # 2. Convert to an ODE derivative
        derivative = (sample - denoised) / sigma_hat

        dt = self.sigmas[self.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        if (
            not get_runtime_state().patch_mode
            or get_runtime_state().pipeline_patch_idx
            == get_runtime_state().num_pipeline_patch - 1
        ):
            self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
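
For the flow-matching Euler update above, when `gamma == 0` (no churn) the algebra collapses: `denoised = sample - model_output * sigma`, so `derivative = (sample - denoised) / sigma == model_output` and `prev_sample = sample + model_output * (sigma_next - sigma)`. A quick numerical check of that identity with toy tensors (not part of the commit):

    import torch

    sample = torch.randn(2, 4)
    model_output = torch.randn(2, 4)
    sigma, sigma_next = 0.8, 0.6

    denoised = sample - model_output * sigma
    derivative = (sample - denoised) / sigma          # equals model_output
    prev_sample = sample + derivative * (sigma_next - sigma)

    assert torch.allclose(prev_sample, sample + model_output * (sigma_next - sigma))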
xfuser1/parallel.py  (new file, mode 100755)
import os
from pathlib import Path

from xfuser.config.config import InputConfig
from xfuser.core.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
)
from xfuser.config import EngineConfig
from xfuser.core.distributed.parallel_state import (
    get_data_parallel_rank,
    get_data_parallel_world_size,
    is_dp_last_group,
)
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.logger import init_logger
from xfuser.model_executor.pipelines.base_pipeline import xFuserPipelineBaseWrapper
from xfuser.model_executor.pipelines.register import xFuserPipelineWrapperRegister

logger = init_logger(__name__)


class xDiTParallel:
    def __init__(self, pipe, engine_config: EngineConfig, input_config: InputConfig):
        xfuser_pipe_wrapper = xFuserPipelineWrapperRegister.get_class(pipe)
        self.pipe = xfuser_pipe_wrapper(pipeline=pipe, engine_config=engine_config)
        self.config = engine_config
        self.pipe.prepare_run(input_config)

    def __call__(
        self,
        *args,
        **kwargs,
    ):
        self.result = self.pipe(*args, **kwargs)
        return self.result

    def save(self, directory: str, prefix: str):
        dp_rank = get_data_parallel_rank()
        parallel_info = (
            f"dp{self.config.parallel_config.dp_degree}_cfg{self.config.parallel_config.cfg_degree}_"
            f"ulysses{self.config.parallel_config.ulysses_degree}_ring{self.config.parallel_config.ring_degree}_"
            f"pp{self.config.parallel_config.pp_degree}_patch{self.config.parallel_config.pp_config.num_pipeline_patch}"
        )
        if is_dp_last_group():
            path = Path(f"{directory}")
            path.mkdir(mode=755, parents=True, exist_ok=True)
            path = path / f"{prefix}_result_{parallel_info}_dprank{dp_rank}"
            for i, image in enumerate(self.result.images):
                image.save(f"{str(path)}_image{i}.png")
                print(f"{str(path)}_image{i}.png")

    def __del__(self):
        get_runtime_state().destory_distributed_env()
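
A hedged end-to-end sketch of `xDiTParallel` (not part of this commit). It assumes the distributed environment, `engine_config`, and `input_config` have already been set up by the surrounding launcher (for example one process per GPU under torchrun); the checkpoint name is illustrative only.

    import torch
    from diffusers import StableDiffusion3Pipeline

    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    )
    parallel_pipe = xDiTParallel(pipe, engine_config, input_config)  # wraps and warms up the pipeline
    parallel_pipe(
        prompt="a mountain lake at sunrise",
        num_inference_steps=28,
        generator=torch.Generator(device="cuda").manual_seed(0),
    )
    parallel_pipe.save("results", "sd3")  # only the last data-parallel rank writes PNGs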