TS-MODELS-OPT / training / Video-Generation-Model · Commits

Commit c07946d8
authored Apr 09, 2026 by hepj

    dit & video

Changes: 270 (270 of 270+ files displayed)
Showing 10 changed files with 4108 additions and 0 deletions (+4108 / -0)
FastVideo-main/fastvideo/v1/models/schedulers/scheduling_unipc_multistep.py  +1162  -0
FastVideo-main/fastvideo/v1/models/utils.py                                  +120   -0
FastVideo-main/fastvideo/v1/models/vaes/common.py                            +528   -0
FastVideo-main/fastvideo/v1/models/vaes/hunyuanvae.py                        +850   -0
FastVideo-main/fastvideo/v1/models/vaes/wanvae.py                            +962   -0
FastVideo-main/fastvideo/v1/models/vision_utils.py                           +220   -0
FastVideo-main/fastvideo/v1/pipelines/README.md                              +3     -0
FastVideo-main/fastvideo/v1/pipelines/__init__.py                            +57    -0
FastVideo-main/fastvideo/v1/pipelines/composed_pipeline_base.py              +206   -0
FastVideo-main/fastvideo/v1/pipelines/hunyuan/__init__.py                    +0     -0
FastVideo-main/fastvideo/v1/models/schedulers/scheduling_unipc_multistep.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
# ==============================================================================
#
# Modified from diffusers==0.33.0.dev0
#
# ==============================================================================
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
                                                   SchedulerMixin,
                                                   SchedulerOutput)
from diffusers.utils import deprecate, is_scipy_available

from fastvideo.v1.models.schedulers.base import BaseScheduler

if is_scipy_available():
    import scipy.stats
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t: float) -> float:
            return math.cos((t + 0.008) / 1.008 * math.pi / 2)**2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t: float) -> float:
            return math.exp(t * -12.0)

    else:
        raise ValueError(
            f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 -
                                            alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):
    """
    `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, default `2`):
            The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
            due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
            unconditional sampling.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
        predict_x0 (`bool`, defaults to `True`):
            Whether to use the updating algorithm on the predicted x0.
        solver_type (`str`, default `bh2`):
            Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
            otherwise.
        lower_order_final (`bool`, default `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
        disable_corrector (`list`, default `[]`):
            Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
            and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
            usually disabled during the first few steps.
        solver_p (`SchedulerMixin`, default `None`):
            Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        final_sigmas_type (`str`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        lower_order_final: bool = True,
        disable_corrector: Tuple[int, ...] = (),
        solver_p: SchedulerMixin = None,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        use_flow_sigmas: Optional[bool] = False,
        flow_shift: Optional[float] = 1.0,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
        rescale_betas_zero_snr: bool = False,
    ):
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError(
                "Make sure to install scipy if you want to use beta sigmas.")
        if sum([
                self.config.use_beta_sigmas,
                self.config.use_exponential_sigmas,
                self.config.use_karras_sigmas
        ]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start,
                                        beta_end,
                                        num_train_timesteps,
                                        dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5,
                                        beta_end**0.5,
                                        num_train_timesteps,
                                        dtype=torch.float32)**2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(
                f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod)**0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                self.register_to_config(solver_type="bh2")
            else:
                raise NotImplementedError(
                    f"{solver_type} is not implemented for {self.__class__}")

        self.predict_x0 = predict_x0
        # setable values
        self.num_inference_steps: Optional[int] = None
        timesteps = np.linspace(0,
                                num_train_timesteps - 1,
                                num_train_timesteps,
                                dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.timestep_list: List[Union[int, torch.Tensor]] = [None
                                                              ] * solver_order
        self.lower_order_nums = 0
        self.disable_corrector = list(disable_corrector)
        self.solver_p = solver_p
        self.last_sample = None
        self._step_index: Optional[int] = None
        self._begin_index: Optional[int] = None
        self.sigmas = self.sigmas.to(
            "cpu")  # to avoid too much CPU/GPU communication

        BaseScheduler.__init__(self)
    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    def set_shift(self, shift: float) -> None:
        self.config.flow_shift = shift

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index
    def set_timesteps(self,
                      num_inference_steps: int,
                      device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (np.linspace(
                0, self.config.num_train_timesteps - 1, num_inference_steps +
                1).round()[::-1][:-1].copy().astype(np.int64))
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // (
                num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps + 1) *
                         step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(self.config.num_train_timesteps, 0,
                                  -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(
            ((1 - self.alphas_cumprod) / self.alphas_cumprod)**0.5)
        if self.config.use_karras_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(
                in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([
                self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas
            ]).round()
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = sigmas[-1]
            elif self.config.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        elif self.config.use_exponential_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_exponential(
                in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array(
                [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = sigmas[-1]
            elif self.config.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        elif self.config.use_beta_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_beta(
                in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array(
                [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = sigmas[-1]
            elif self.config.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        elif self.config.use_flow_sigmas:
            alphas = np.linspace(1, 1 / self.config.num_train_timesteps,
                                 num_inference_steps + 1)
            sigmas = 1.0 - alphas
            sigmas = np.flip(self.config.flow_shift * sigmas /
                             (1 + (self.config.flow_shift - 1) *
                              sigmas))[:-1].copy()
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = sigmas[-1]
            elif self.config.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = ((1 - self.alphas_cumprod[0]) /
                              self.alphas_cumprod[0])**0.5
            elif self.config.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device,
                                                        dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        if self.solver_p:
            self.solver_p.set_timesteps(self.num_inference_steps,
                                        device=device)

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to(
            "cpu")  # to avoid too much CPU/GPU communication
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float(
            )  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample,
                           self.config.dynamic_thresholding_ratio,
                           dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(
            1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(
            sample, -s, s
        ) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas) -> torch.Tensor:
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(
            max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(
            self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
        else:
            alpha_t = 1 / ((sigma**2 + 1)**0.5)
            sigma_t = sigma * alpha_t

        return alpha_t, sigma_t
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor,
                           num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min**(1 / rho)
        max_inv_rho = sigma_max**(1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
        return sigmas
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor,
                                num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(
            np.linspace(math.log(sigma_max), math.log(sigma_min),
                        num_inference_steps))
        return sigmas
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(self,
                         in_sigmas: torch.Tensor,
                         num_inference_steps: int,
                         alpha: float = 0.6,
                         beta: float = 0.6) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.array([
            sigma_min + (ppf * (sigma_max - sigma_min)) for ppf in [
                scipy.stats.beta.ppf(timestep, alpha, beta)
                for timestep in 1 - np.linspace(0, 1, num_inference_steps)
            ]
        ])
        return sigmas
    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Convert the model output to the corresponding type the UniPC algorithm needs.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError(
                    "missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)

        if self.predict_x0:
            if self.config.prediction_type == "epsilon":
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                x0_pred = alpha_t * sample - sigma_t * model_output
            elif self.config.prediction_type == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                    "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred
        else:
            if self.config.prediction_type == "epsilon":
                return model_output
            elif self.config.prediction_type == "sample":
                epsilon = (sample - alpha_t * model_output) / sigma_t
                return epsilon
            elif self.config.prediction_type == "v_prediction":
                epsilon = alpha_t * model_output + sigma_t * sample
                return epsilon
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the UniPCMultistepScheduler.")
    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        order: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model at the current timestep.
            prev_timestep (`int`):
                The previous discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`):
                The order of UniP at this timestep (corresponds to the *p* in UniPC-p).

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
            "prev_timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError(
                    "missing `sample` as a required keyword argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError(
                    "missing `order` as a required keyword argument")
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        model_output_list = self.model_outputs

        s0 = self.timestep_list[-1]
        m0 = model_output_list[-1]
        x = sample

        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        sigma_t, sigma_s0 = self.sigmas[self.step_index +
                                        1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - i
            mi: torch.Tensor = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R_tensor: torch.Tensor = torch.stack(R)
        b = torch.tensor(b, device=device)

        D1s_tensor: Optional[torch.Tensor] = None
        if len(D1s) > 0:
            D1s_tensor = torch.stack(D1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R_tensor[:-1, :-1],
                                            b[:-1]).to(device).to(x.dtype)
        else:
            D1s_tensor = None

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s_tensor is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s_tensor)
            else:
                pred_res = 0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s_tensor is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s_tensor)
            else:
                pred_res = 0
            x_t = x_t_ - sigma_t * B_h * pred_res

        x_t = x_t.to(x.dtype)
        return x_t
    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.Tensor,
        *args,
        last_sample: Optional[torch.Tensor] = None,
        this_sample: Optional[torch.Tensor] = None,
        order: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniC (B(h) version).

        Args:
            this_model_output (`torch.Tensor`):
                The model outputs at `x_t`.
            this_timestep (`int`):
                The current timestep `t`.
            last_sample (`torch.Tensor`):
                The generated sample before the last predictor `x_{t-1}`.
            this_sample (`torch.Tensor`):
                The generated sample after the last predictor `x_{t}`.
            order (`int`):
                The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.

        Returns:
            `torch.Tensor`:
                The corrected sample tensor at the current timestep.
        """
        this_timestep = args[0] if len(args) > 0 else kwargs.pop(
            "this_timestep", None)
        if last_sample is None:
            if len(args) > 1:
                last_sample = args[1]
            else:
                raise ValueError(
                    "missing `last_sample` as a required keyword argument")
        if this_sample is None:
            if len(args) > 2:
                this_sample = args[2]
            else:
                raise ValueError(
                    "missing `this_sample` as a required keyword argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError(
                    "missing `order` as a required keyword argument")
        if this_timestep is not None:
            deprecate(
                "this_timestep",
                "1.0.0",
                "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output

        sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[
            self.step_index - 1]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = this_sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - (i + 1)
            mi: torch.Tensor = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        D1s_tensor: Optional[torch.Tensor] = torch.stack(
            D1s, dim=1) if len(D1s) > 0 else None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s_tensor is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1],
                                        D1s_tensor)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s_tensor is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1],
                                        D1s_tensor)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        x_t = x_t.to(x.dtype)
        return x_t
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None) -> int:
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        index_candidates = (schedule_timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            step_index = len(self.timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

        return step_index
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
    def _init_step_index(self, timestep) -> None:
        """
        Initialize the step_index counter for the scheduler.
        """
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index
    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep UniPC.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)
        assert self.step_index is not None

        use_corrector = (self.step_index > 0
                         and self.step_index - 1 not in self.disable_corrector
                         and self.last_sample is not None)

        model_output_convert = self.convert_model_output(model_output,
                                                         sample=sample)
        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        if self.config.lower_order_final:
            this_order = min(self.config.solver_order,
                             len(self.timesteps) - self.step_index)
        else:
            this_order = self.config.solver_order

        self.this_order: int = min(
            this_order, self.lower_order_nums + 1)  # warmup for multistep
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # upon completion increase step index by one
        assert self._step_index is not None
        self._step_index += 1

        if not return_dict:
            return (prev_sample, )

        return SchedulerOutput(prev_sample=prev_sample)
    def scale_model_input(self, sample: torch.Tensor, *args,
                          **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device,
                                dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(
                timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device,
                                                   dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device,
                                     dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps)
                for t in timesteps
            ]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples
    def __len__(self):
        return self.config.num_train_timesteps
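
A minimal usage sketch of the scheduler above (not part of the commit; `model`, the latent shape, and the flow-matching settings are placeholder assumptions — any denoiser whose output matches the configured `prediction_type` would do):

    import torch
    from fastvideo.v1.models.schedulers.scheduling_unipc_multistep import (
        UniPCMultistepScheduler)

    scheduler = UniPCMultistepScheduler(use_flow_sigmas=True, flow_shift=7.0,
                                        prediction_type="flow_prediction")
    scheduler.set_timesteps(num_inference_steps=50, device="cuda")

    latents = torch.randn(1, 16, 21, 60, 104, device="cuda")  # hypothetical latent shape
    for t in scheduler.timesteps:
        model_output = model(latents, t)  # `model` is an assumed denoiser callable
        latents = scheduler.step(model_output, t, latents).prev_sample
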
FastVideo-main/fastvideo/v1/models/utils.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0
# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/utils.py
"""Utils for model executor."""
from typing import Any, Dict, List, Optional

import torch
# TODO(PY): move it elsewhere
def auto_attributes(init_func):
    """
    Decorator that automatically adds all initialization arguments as object attributes.

    Example:
        @auto_attributes
        def __init__(self, a=1, b=2):
            pass

        # This will automatically set:
        # - self.a = 1 and self.b = 2
        # - self.config.a = 1 and self.config.b = 2
    """

    def wrapper(self, *args, **kwargs):
        # Get the function signature
        import inspect
        signature = inspect.signature(init_func)
        parameters = signature.parameters

        # Get parameter names (excluding 'self')
        param_names = list(parameters.keys())[1:]

        # Bind arguments to parameters
        bound_args = signature.bind(self, *args, **kwargs)
        bound_args.apply_defaults()

        # Create config object if it doesn't exist
        if not hasattr(self, 'config'):
            self.config = type('Config', (), {})()

        # Set attributes on self and self.config
        for name in param_names:
            if name in bound_args.arguments:
                value = bound_args.arguments[name]
                setattr(self, name, value)
                setattr(self.config, name, value)

        # Call the original __init__ function
        return init_func(self, *args, **kwargs)

    return wrapper
def set_random_seed(seed: int) -> None:
    from fastvideo.v1.platforms import current_platform
    current_platform.seed_everything(seed)
def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: Optional[Dict[str, Any]],
):
    """Set attributes on a weight tensor.

    This method is used to set attributes on a weight tensor. This method
    will not overwrite existing attributes.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
    """
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        assert not hasattr(weight, key), (
            f"Overwriting existing tensor attribute: {key}")

        # NOTE(woosuk): During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to the redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        # TODO(woosuk): Remove this hack once we have a better solution.
        from fastvideo.v1.platforms import current_platform
        if current_platform.is_tpu() and key == "weight_loader":
            value = _make_synced_weight_loader(value)
        setattr(weight, key, value)
def _make_synced_weight_loader(original_weight_loader) -> Any:

    def _synced_weight_loader(param, *args, **kwargs):
        original_weight_loader(param, *args, **kwargs)
        torch._sync(param)

    return _synced_weight_loader
def extract_layer_index(layer_name: str) -> int:
    """
    Extract the layer index from the module name.
    Examples:
    - "encoder.layers.0" -> 0
    - "encoder.layers.1.self_attn" -> 1
    - "2.self_attn" -> 2
    - "model.encoder.layers.0.sub.1" -> ValueError
    """
    subnames = layer_name.split(".")
    int_vals: List[int] = []
    for subname in subnames:
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue
    assert len(int_vals) == 1, (f"layer name {layer_name} should"
                                " only contain one integer")
    return int_vals[0]
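
A quick sketch of `auto_attributes` in use (standalone; the class and values here are made up):

    from fastvideo.v1.models.utils import auto_attributes

    class Dummy:

        @auto_attributes
        def __init__(self, hidden_size=512, num_layers=4):
            pass

    d = Dummy(num_layers=8)
    print(d.hidden_size, d.num_layers)                # 512 8
    print(d.config.hidden_size, d.config.num_layers)  # 512 8
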
FastVideo-main/fastvideo/v1/models/vaes/common.py (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from math import prod
from typing import Iterator, Optional, Tuple, Union, cast

import numpy as np
import torch
import torch.distributed as dist
from diffusers.utils.torch_utils import randn_tensor

from fastvideo.v1.configs.models import VAEConfig
from fastvideo.v1.distributed import (get_sequence_model_parallel_rank,
                                      get_sequence_model_parallel_world_size)
class ParallelTiledVAE(ABC):
    tile_sample_min_height: int
    tile_sample_min_width: int
    tile_sample_min_num_frames: int
    tile_sample_stride_height: int
    tile_sample_stride_width: int
    tile_sample_stride_num_frames: int
    blend_num_frames: int
    use_tiling: bool
    use_temporal_tiling: bool
    use_parallel_tiling: bool

    def __init__(self, config: VAEConfig, **kwargs) -> None:
        self.config = config
        self.tile_sample_min_height = config.tile_sample_min_height
        self.tile_sample_min_width = config.tile_sample_min_width
        self.tile_sample_min_num_frames = config.tile_sample_min_num_frames
        self.tile_sample_stride_height = config.tile_sample_stride_height
        self.tile_sample_stride_width = config.tile_sample_stride_width
        self.tile_sample_stride_num_frames = config.tile_sample_stride_num_frames
        self.blend_num_frames = config.blend_num_frames
        self.use_tiling = config.use_tiling
        self.use_temporal_tiling = config.use_temporal_tiling
        self.use_parallel_tiling = config.use_parallel_tiling

    @property
    def temporal_compression_ratio(self) -> int:
        return cast(int, self.config.temporal_compression_ratio)

    @property
    def spatial_compression_ratio(self) -> int:
        return cast(int, self.config.spatial_compression_ratio)

    @property
    def scaling_factor(self) -> Union[float, torch.Tensor]:
        return cast(Union[float, torch.Tensor], self.config.scaling_factor)
    @abstractmethod
    def _encode(self, *args, **kwargs) -> torch.Tensor:
        pass

    @abstractmethod
    def _decode(self, *args, **kwargs) -> torch.Tensor:
        pass
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = x.shape
        latent_num_frames = (num_frames -
                             1) // self.temporal_compression_ratio + 1
        if self.use_tiling and self.use_temporal_tiling and num_frames > self.tile_sample_min_num_frames:
            latents = self.tiled_encode(x)[:, :, :latent_num_frames]
        elif self.use_tiling and (width > self.tile_sample_min_width
                                  or height > self.tile_sample_min_height):
            latents = self.spatial_tiled_encode(x)[:, :, :latent_num_frames]
        else:
            latents = self._encode(x)[:, :, :latent_num_frames]
        return DiagonalGaussianDistribution(latents)
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = z.shape
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
        num_sample_frames = (num_frames -
                             1) * self.temporal_compression_ratio + 1
        if self.use_tiling and self.use_parallel_tiling and get_sequence_model_parallel_world_size(
        ) > 1:
            return self.parallel_tiled_decode(z)[:, :, :num_sample_frames]
        if self.use_tiling and self.use_temporal_tiling and num_frames > tile_latent_min_num_frames:
            return self.tiled_decode(z)[:, :, :num_sample_frames]
        if self.use_tiling and (width > tile_latent_min_width
                                or height > tile_latent_min_height):
            return self.spatial_tiled_decode(z)[:, :, :num_sample_frames]
        return self._decode(z)[:, :, :num_sample_frames]
    def blend_v(self, a: torch.Tensor, b: torch.Tensor,
                blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
                1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor,
                blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
                1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
        return b

    def blend_t(self, a: torch.Tensor, b: torch.Tensor,
                blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
        for x in range(blend_extent):
            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (
                1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
        return b
    def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""Encode a batch of images using a tiled encoder.

        Args:
            x (`torch.Tensor`): Input batch of videos.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """
        _, _, _, height, width = x.shape
        # latent_height = height // self.spatial_compression_ratio
        # latent_width = width // self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = tile_latent_min_height - tile_latent_stride_height
        blend_width = tile_latent_min_width - tile_latent_stride_width

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, self.tile_sample_stride_height):
            row = []
            for j in range(0, width, self.tile_sample_stride_width):
                tile = x[:, :, :, i:i + self.tile_sample_min_height,
                         j:j + self.tile_sample_min_width]
                tile = self._encode(tile)
                row.append(tile)
            rows.append(row)
        return self._merge_spatial_tiles(rows, blend_height, blend_width,
                                         tile_latent_stride_height,
                                         tile_latent_stride_width)
    def _parallel_data_generator(
            self, gathered_results,
            gathered_dim_metadata) -> Iterator[Tuple[torch.Tensor, int]]:
        global_idx = 0
        for i, per_rank_metadata in enumerate(gathered_dim_metadata):
            _start_shape = 0
            for shape in per_rank_metadata:
                mul_shape = prod(shape)
                yield (gathered_results[
                    i, _start_shape:_start_shape + mul_shape].reshape(shape),
                       global_idx)
                _start_shape += mul_shape
                global_idx += 1
    def parallel_tiled_decode(self,
                              z: torch.FloatTensor) -> torch.FloatTensor:
        """
        Parallel version of tiled_decode that distributes both temporal and spatial computation across GPUs
        """
        world_size, rank = get_sequence_model_parallel_world_size(
        ), get_sequence_model_parallel_rank()
        B, C, T, H, W = z.shape

        # Calculate parameters
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

        # Calculate tile dimensions
        num_t_tiles = (T + tile_latent_stride_num_frames -
                       1) // tile_latent_stride_num_frames
        num_h_tiles = (H + tile_latent_stride_height -
                       1) // tile_latent_stride_height
        num_w_tiles = (W + tile_latent_stride_width -
                       1) // tile_latent_stride_width
        total_spatial_tiles = num_h_tiles * num_w_tiles
        total_tiles = num_t_tiles * total_spatial_tiles

        # Calculate tiles per rank and padding
        tiles_per_rank = (total_tiles + world_size - 1) // world_size
        start_tile_idx = rank * tiles_per_rank
        end_tile_idx = min((rank + 1) * tiles_per_rank, total_tiles)

        local_results = []
        local_dim_metadata = []
        # Process assigned tiles
        for local_idx, global_idx in enumerate(
                range(start_tile_idx, end_tile_idx)):
            t_idx = global_idx // total_spatial_tiles
            spatial_idx = global_idx % total_spatial_tiles
            h_idx = spatial_idx // num_w_tiles
            w_idx = spatial_idx % num_w_tiles

            # Calculate positions
            t_start = t_idx * tile_latent_stride_num_frames
            h_start = h_idx * tile_latent_stride_height
            w_start = w_idx * tile_latent_stride_width

            # Extract and process tile
            tile = z[:, :, t_start:t_start + tile_latent_min_num_frames + 1,
                     h_start:h_start + tile_latent_min_height,
                     w_start:w_start + tile_latent_min_width]

            # Process tile
            tile = self._decode(tile)
            if t_start > 0:
                tile = tile[:, :, 1:, :, :]

            # Store metadata
            shape = tile.shape
            # Store decoded data (flattened)
            decoded_flat = tile.reshape(-1)
            local_results.append(decoded_flat)
            local_dim_metadata.append(shape)

        results = torch.cat(local_results, dim=0).contiguous()
        del local_results
        torch.cuda.empty_cache()

        # first gather size to pad the results
        local_size = torch.tensor([results.size(0)],
                                  device=results.device,
                                  dtype=torch.int64)
        all_sizes = [
            torch.zeros(1, device=results.device, dtype=torch.int64)
            for _ in range(world_size)
        ]
        dist.all_gather(all_sizes, local_size)
        max_size = max(size.item() for size in all_sizes)

        padded_results = torch.zeros(max_size, device=results.device)
        padded_results[:results.size(0)] = results
        del results
        torch.cuda.empty_cache()

        # Gather all results
        gathered_dim_metadata = [None] * world_size
        gathered_results = torch.zeros_like(padded_results).repeat(
            world_size, *[1] * len(padded_results.shape)).contiguous(
            )  # use contiguous to make sure it won't copy data in the following operations
        # TODO (PY): use fastvideo distributed methods
        dist.all_gather_into_tensor(gathered_results, padded_results)
        dist.all_gather_object(gathered_dim_metadata, local_dim_metadata)

        # Process gathered results
        data: list = [[[[] for _ in range(num_w_tiles)]
                       for _ in range(num_h_tiles)]
                      for _ in range(num_t_tiles)]
        for current_data, global_idx in self._parallel_data_generator(
                gathered_results, gathered_dim_metadata):
            t_idx = global_idx // total_spatial_tiles
            spatial_idx = global_idx % total_spatial_tiles
            h_idx = spatial_idx // num_w_tiles
            w_idx = spatial_idx % num_w_tiles
            data[t_idx][h_idx][w_idx] = current_data

        # Merge results
        result_slices = []
        last_slice_data = None
        for i, tem_data in enumerate(data):
            slice_data = self._merge_spatial_tiles(
                tem_data, blend_height, blend_width,
                self.tile_sample_stride_height, self.tile_sample_stride_width)
            if i > 0:
                slice_data = self.blend_t(last_slice_data, slice_data,
                                          self.blend_num_frames)
                result_slices.append(
                    slice_data[:, :, :self.tile_sample_stride_num_frames, :, :])
            else:
                result_slices.append(
                    slice_data[:, :, :self.tile_sample_stride_num_frames +
                               1, :, :])
            last_slice_data = slice_data

        dec = torch.cat(result_slices, dim=2)
        return dec
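
    # Illustrative sketch (not part of the original file): the tile-to-rank
    # mapping used above. A flat global tile index decomposes into (t, h, w)
    # grid coordinates, and contiguous index ranges are assigned to ranks.
    # The values below are made up.
    @staticmethod
    def _demo_tile_index_math(total_tiles=10,
                              world_size=4,
                              num_h_tiles=2,
                              num_w_tiles=3):
        total_spatial_tiles = num_h_tiles * num_w_tiles
        tiles_per_rank = (total_tiles + world_size - 1) // world_size  # ceil
        for rank in range(world_size):
            start = rank * tiles_per_rank
            end = min((rank + 1) * tiles_per_rank, total_tiles)
            coords = [(g // total_spatial_tiles,
                       (g % total_spatial_tiles) // num_w_tiles,
                       g % num_w_tiles) for g in range(start, end)]
            print(f"rank {rank}: tiles {start}..{end - 1} -> {coords}")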
    def _merge_spatial_tiles(self, tiles, blend_height, blend_width,
                             stride_height, stride_width) -> torch.Tensor:
        """Helper function to merge spatial tiles with blending"""
        result_rows = []
        for i, row in enumerate(tiles):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(tiles[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, :stride_height, :stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))
        return torch.cat(result_rows, dim=-2)
    def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.

        Returns:
            `torch.Tensor`:
                The decoded images.
        """
        _, _, _, height, width = z.shape
        # sample_height = height * self.spatial_compression_ratio
        # sample_width = width * self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, tile_latent_stride_height):
            row = []
            for j in range(0, width, tile_latent_stride_width):
                tile = z[:, :, :, i:i + tile_latent_min_height,
                         j:j + tile_latent_min_width]
                decoded = self._decode(tile)
                row.append(decoded)
            rows.append(row)
        return self._merge_spatial_tiles(rows, blend_height, blend_width,
                                         self.tile_sample_stride_height,
                                         self.tile_sample_stride_width)
    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        _, _, num_frames, height, width = x.shape
        # tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio

        row = []
        for i in range(0, num_frames, self.tile_sample_stride_num_frames):
            tile = x[:, :, i:i + self.tile_sample_min_num_frames + 1, :, :]
            if self.use_tiling and (height > self.tile_sample_min_height
                                    or width > self.tile_sample_min_width):
                tile = self.spatial_tiled_encode(tile)
            else:
                tile = self._encode(tile)
            if i > 0:
                tile = tile[:, :, 1:, :, :]
            row.append(tile)

        result_row = []
        for i, tile in enumerate(row):
            if i > 0:
                tile = self.blend_t(row[i - 1], tile, self.blend_num_frames)
                result_row.append(
                    tile[:, :, :tile_latent_stride_num_frames, :, :])
            else:
                result_row.append(
                    tile[:, :, :tile_latent_stride_num_frames + 1, :, :])

        enc = torch.cat(result_row, dim=2)
        return enc
    def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = z.shape
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio

        row = []
        for i in range(0, num_frames, tile_latent_stride_num_frames):
            tile = z[:, :, i:i + tile_latent_min_num_frames + 1, :, :]
            if self.use_tiling and (tile.shape[-1] > tile_latent_min_width
                                    or tile.shape[-2] > tile_latent_min_height):
                decoded = self.spatial_tiled_decode(tile)
            else:
                decoded = self._decode(tile)
            if i > 0:
                decoded = decoded[:, :, 1:, :, :]
            row.append(decoded)

        result_row = []
        for i, tile in enumerate(row):
            if i > 0:
                tile = self.blend_t(row[i - 1], tile, self.blend_num_frames)
                result_row.append(
                    tile[:, :, :self.tile_sample_stride_num_frames, :, :])
            else:
                result_row.append(
                    tile[:, :, :self.tile_sample_stride_num_frames + 1, :, :])

        dec = torch.cat(result_row, dim=2)
        return dec
    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
        tile_sample_min_width: Optional[int] = None,
        tile_sample_min_num_frames: Optional[int] = None,
        tile_sample_stride_height: Optional[int] = None,
        tile_sample_stride_width: Optional[int] = None,
        tile_sample_stride_num_frames: Optional[int] = None,
        blend_num_frames: Optional[int] = None,
        use_tiling: Optional[bool] = None,
        use_temporal_tiling: Optional[bool] = None,
        use_parallel_tiling: Optional[bool] = None,
    ) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.

        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_min_num_frames (`int`, *optional*):
                The minimum number of frames required for a sample to be separated into tiles across the frame
                dimension.
            tile_sample_stride_height (`int`, *optional*):
                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
                artifacts produced across the height dimension.
            tile_sample_stride_width (`int`, *optional*):
                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
                artifacts produced across the width dimension.
            tile_sample_stride_num_frames (`int`, *optional*):
                The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts
                produced across the frame dimension.
        """
        self.use_tiling = True
        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
        if blend_num_frames is not None:
            self.blend_num_frames = blend_num_frames
        else:
            self.blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
        self.use_tiling = use_tiling or self.use_tiling
        self.use_temporal_tiling = use_temporal_tiling or self.use_temporal_tiling
        self.use_parallel_tiling = use_parallel_tiling or self.use_parallel_tiling

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_tiling = False
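
    # Illustrative sketch (not part of the original file; the sizes below are
    # made up): enable_tiling() stores sample-space sizes, and the latent-space
    # tile sizes used by the tiled paths fall out of the compression ratios.
    @staticmethod
    def _demo_tiling_arithmetic(tile_min=256,
                                tile_stride=192,
                                spatial_compression_ratio=8):
        blend = tile_min - tile_stride  # samples blended between neighbors
        latent_min = tile_min // spatial_compression_ratio
        latent_stride = tile_stride // spatial_compression_ratio
        return blend, latent_min, latent_stride  # (64, 32, 24)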
# adapted from https://github.com/huggingface/diffusers/blob/e7ffeae0a191f710881d1fbde00cd6ff025e81f2/src/diffusers/models/autoencoders/vae.py#L691
class DiagonalGaussianDistribution:

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean,
                device=self.parameters.device,
                dtype=self.parameters.dtype)

    def sample(self,
               generator: Optional[torch.Generator] = None) -> torch.Tensor:
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(
        self,
        other: Optional["DiagonalGaussianDistribution"] = None
    ) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var +
                    self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self,
            sample: torch.Tensor,
            dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar +
            torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        return self.mean
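

# Illustrative sketch (not part of the original file): exercising the
# distribution above. With mean 0 and logvar 0 (unit variance), the KL
# divergence against a standard normal is exactly 0 and mode() is the mean.
def _demo_diagonal_gaussian():
    params = torch.zeros(1, 4, 2, 2)  # first 2 channels mean, last 2 logvar
    dist = DiagonalGaussianDistribution(params)
    assert torch.allclose(dist.kl(), torch.zeros(1))
    assert torch.allclose(dist.mode(), torch.zeros(1, 2, 2, 2))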
FastVideo-main/fastvideo/v1/models/vaes/hunyuanvae.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# Adapted from diffusers
# Copyright 2024 The Hunyuan Team, The HuggingFace Team and The FastVideo Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from fastvideo.v1.configs.models.vaes import HunyuanVAEConfig
from fastvideo.v1.layers.activation import get_act_fn
from fastvideo.v1.models.vaes.common import ParallelTiledVAE
def prepare_causal_attention_mask(
        num_frames: int,
        height_width: int,
        dtype: torch.dtype,
        device: torch.device,
        batch_size: Optional[int] = None) -> torch.Tensor:
    indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
    indices_blocks = indices.repeat_interleave(height_width)
    x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
    mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)

    if batch_size is not None:
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask
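

# Illustrative sketch (not part of the original file): the block-causal
# structure produced above. With 2 frames of 2 spatial tokens each, a token
# may attend to all tokens of its own frame and of earlier frames, while
# positions in later frames are masked with -inf.
def _demo_causal_mask():
    mask = prepare_causal_attention_mask(2, 2, torch.float32,
                                         torch.device("cpu"))
    assert mask.shape == (4, 4)
    # row 0 (frame 1) cannot see column 2 (frame 2); row 3 (frame 2) sees all
    assert torch.isinf(mask[0, 2]) and mask[3, 0] == 0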
class HunyuanVAEAttention(nn.Module):

    def __init__(self, in_channels, heads, dim_head, eps, norm_num_groups,
                 bias) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.heads = heads
        self.dim_head = dim_head
        self.eps = eps
        self.norm_num_groups = norm_num_groups
        self.bias = bias

        inner_dim = heads * dim_head

        # Define the projection layers
        self.to_q = nn.Linear(in_channels, inner_dim, bias=bias)
        self.to_k = nn.Linear(in_channels, inner_dim, bias=bias)
        self.to_v = nn.Linear(in_channels, inner_dim, bias=bias)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, in_channels, bias=bias))

        # Optional normalization layers
        self.group_norm = nn.GroupNorm(norm_num_groups,
                                       in_channels,
                                       eps=eps,
                                       affine=True)

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        residual = hidden_states
        batch_size, sequence_length, _ = hidden_states.shape

        hidden_states = self.group_norm(hidden_states.transpose(
            1, 2)).transpose(1, 2)

        # Project to query, key, value
        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)

        # Reshape for multi-head attention
        head_dim = self.dim_head
        query = query.view(batch_size, -1, self.heads,
                           head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads,
                           head_dim).transpose(1, 2)

        # Perform scaled dot-product attention
        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False)

        # Reshape back
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, self.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Linear projection
        hidden_states = self.to_out(hidden_states)

        # Residual connection
        hidden_states = hidden_states + residual
        return hidden_states
class HunyuanVideoCausalConv3d(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        bias: bool = True,
        pad_mode: str = "replicate",
    ) -> None:
        super().__init__()

        kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(
            kernel_size, int) else kernel_size

        self.pad_mode = pad_mode
        self.time_causal_padding = (
            kernel_size[0] // 2,
            kernel_size[0] // 2,
            kernel_size[1] // 2,
            kernel_size[1] // 2,
            kernel_size[2] - 1,
            0,
        )

        self.conv = nn.Conv3d(in_channels,
                              out_channels,
                              kernel_size,
                              stride,
                              padding,
                              dilation,
                              bias=bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = F.pad(hidden_states,
                              self.time_causal_padding,
                              mode=self.pad_mode)
        return self.conv(hidden_states)
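

# Illustrative sketch (not part of the original file): the padding tuple above
# follows F.pad's (w_left, w_right, h_top, h_bottom, t_front, t_back) order.
# For a 3x3x3 kernel it is (1, 1, 1, 1, 2, 0): symmetric in space, but all
# temporal padding in front, so each output frame only sees current and past
# input frames.
def _demo_causal_conv_shapes():
    conv = HunyuanVideoCausalConv3d(3, 8, kernel_size=3)
    video = torch.randn(1, 3, 5, 16, 16)  # (B, C, T, H, W)
    out = conv(video)
    assert out.shape == (1, 8, 5, 16, 16)  # T is preserved by the causal pad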
class HunyuanVideoUpsampleCausal3D(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        kernel_size: int = 3,
        stride: int = 1,
        bias: bool = True,
        upsample_factor: Tuple[int, ...] = (2, 2, 2),
    ) -> None:
        super().__init__()
        out_channels = out_channels or in_channels
        self.upsample_factor = upsample_factor

        self.conv = HunyuanVideoCausalConv3d(in_channels,
                                             out_channels,
                                             kernel_size,
                                             stride,
                                             bias=bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_frames = hidden_states.size(2)

        first_frame, other_frames = hidden_states.split(
            (1, num_frames - 1), dim=2)
        first_frame = F.interpolate(first_frame.squeeze(2),
                                    scale_factor=self.upsample_factor[1:],
                                    mode="nearest").unsqueeze(2)

        if num_frames > 1:
            # See: https://github.com/pytorch/pytorch/issues/81665
            # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate
            # is fixed, this will raise either a runtime error, or fail silently with bad outputs.
            # If you are encountering an error here, make sure to try running encoding/decoding with
            # `vae.enable_tiling()` first. If that doesn't work, open an issue at:
            # https://github.com/huggingface/diffusers/issues
            other_frames = other_frames.contiguous()
            other_frames = F.interpolate(other_frames,
                                         scale_factor=self.upsample_factor,
                                         mode="nearest")
            hidden_states = torch.cat((first_frame, other_frames), dim=2)
        else:
            hidden_states = first_frame

        hidden_states = self.conv(hidden_states)
        return hidden_states
class HunyuanVideoDownsampleCausal3D(nn.Module):

    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        padding: int = 1,
        kernel_size: int = 3,
        bias: bool = True,
        stride=2,
    ) -> None:
        super().__init__()
        out_channels = out_channels or channels

        self.conv = HunyuanVideoCausalConv3d(channels,
                                             out_channels,
                                             kernel_size,
                                             stride,
                                             padding,
                                             bias=bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv(hidden_states)
        return hidden_states
class HunyuanVideoResnetBlockCausal3D(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        non_linearity: str = "silu",
    ) -> None:
        super().__init__()
        out_channels = out_channels or in_channels

        self.nonlinearity = get_act_fn(non_linearity)

        self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True)
        self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1,
                                              0)

        self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1,
                                              0)

        self.conv_shortcut = None
        if in_channels != out_channels:
            self.conv_shortcut = HunyuanVideoCausalConv3d(
                in_channels, out_channels, 1, 1, 0)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = hidden_states.contiguous()
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        hidden_states = hidden_states + residual
        return hidden_states
class HunyuanVideoMidBlock3D(nn.Module):

    def __init__(
        self,
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "silu",
        resnet_groups: int = 32,
        add_attention: bool = True,
        attention_head_dim: int = 1,
    ) -> None:
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(
            in_channels // 4, 32)
        self.add_attention = add_attention

        # There is always at least one resnet
        resnets = [
            HunyuanVideoResnetBlockCausal3D(
                in_channels=in_channels,
                out_channels=in_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                non_linearity=resnet_act_fn,
            )
        ]
        attentions: list[Optional[HunyuanVAEAttention]] = []

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    HunyuanVAEAttention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        eps=resnet_eps,
                        norm_num_groups=resnet_groups,
                        bias=True,
                    ))
            else:
                attentions.append(None)

            resnets.append(
                HunyuanVideoResnetBlockCausal3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    non_linearity=resnet_act_fn,
                ))

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states = self._gradient_checkpointing_func(
                self.resnets[0], hidden_states)

            for attn, resnet in zip(self.attentions, self.resnets[1:]):
                if attn is not None:
                    batch_size, num_channels, num_frames, height, width = hidden_states.shape
                    hidden_states = hidden_states.permute(0, 2, 3, 4,
                                                          1).flatten(1, 3)
                    attention_mask = prepare_causal_attention_mask(
                        num_frames,
                        height * width,
                        hidden_states.dtype,
                        hidden_states.device,
                        batch_size=batch_size)
                    hidden_states = attn(hidden_states,
                                         attention_mask=attention_mask)
                    hidden_states = hidden_states.unflatten(
                        1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)

                hidden_states = self._gradient_checkpointing_func(
                    resnet, hidden_states)
        else:
            hidden_states = self.resnets[0](hidden_states)

            for attn, resnet in zip(self.attentions, self.resnets[1:]):
                if attn is not None:
                    batch_size, num_channels, num_frames, height, width = hidden_states.shape
                    hidden_states = hidden_states.permute(0, 2, 3, 4,
                                                          1).flatten(1, 3)
                    attention_mask = prepare_causal_attention_mask(
                        num_frames,
                        height * width,
                        hidden_states.dtype,
                        hidden_states.device,
                        batch_size=batch_size)
                    hidden_states = attn(hidden_states,
                                         attention_mask=attention_mask)
                    hidden_states = hidden_states.unflatten(
                        1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)

                hidden_states = resnet(hidden_states)

        return hidden_states
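

# Illustrative sketch (not part of the original file): the layout round trip
# used in the mid block. (B, C, T, H, W) is flattened to (B, T*H*W, C) for
# attention, then unflattened and permuted back; the round trip is lossless.
def _demo_attention_layout_round_trip():
    x = torch.randn(2, 4, 3, 5, 5)
    flat = x.permute(0, 2, 3, 4, 1).flatten(1, 3)  # (B, T*H*W, C)
    restored = flat.unflatten(1, (3, 5, 5)).permute(0, 4, 1, 2, 3)
    assert torch.equal(x, restored)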
class HunyuanVideoDownBlock3D(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "silu",
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_stride: Tuple[int, ...] | int = 2,
        downsample_padding: int = 1,
    ) -> None:
        super().__init__()
        resnets = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                HunyuanVideoResnetBlockCausal3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    non_linearity=resnet_act_fn,
                ))

        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList([
                HunyuanVideoDownsampleCausal3D(
                    out_channels,
                    out_channels=out_channels,
                    padding=downsample_padding,
                    stride=downsample_stride,
                )
            ])
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for resnet in self.resnets:
                hidden_states = self._gradient_checkpointing_func(
                    resnet, hidden_states)
        else:
            for resnet in self.resnets:
                hidden_states = resnet(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states
class HunyuanVideoUpBlock3D(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "silu",
        resnet_groups: int = 32,
        add_upsample: bool = True,
        upsample_scale_factor: Tuple[int, ...] = (2, 2, 2),
    ) -> None:
        super().__init__()
        resnets = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels
            resnets.append(
                HunyuanVideoResnetBlockCausal3D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    non_linearity=resnet_act_fn,
                ))

        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([
                HunyuanVideoUpsampleCausal3D(
                    out_channels,
                    out_channels=out_channels,
                    upsample_factor=upsample_scale_factor,
                )
            ])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for resnet in self.resnets:
                hidden_states = self._gradient_checkpointing_func(
                    resnet, hidden_states)
        else:
            for resnet in self.resnets:
                hidden_states = resnet(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
class HunyuanVideoEncoder3D(nn.Module):
    r"""
    Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = (
            "HunyuanVideoDownBlock3D",
            "HunyuanVideoDownBlock3D",
            "HunyuanVideoDownBlock3D",
            "HunyuanVideoDownBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        temporal_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ) -> None:
        super().__init__()

        self.conv_in = HunyuanVideoCausalConv3d(in_channels,
                                                block_out_channels[0],
                                                kernel_size=3,
                                                stride=1)
        self.mid_block: Optional[HunyuanVideoMidBlock3D] = None
        self.down_blocks = nn.ModuleList([])

        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            if down_block_type != "HunyuanVideoDownBlock3D":
                raise ValueError(
                    f"Unsupported down_block_type: {down_block_type}")

            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_downsample_layers = int(
                np.log2(spatial_compression_ratio))
            num_time_downsample_layers = int(
                np.log2(temporal_compression_ratio))

            if temporal_compression_ratio == 4:
                add_spatial_downsample = bool(
                    i < num_spatial_downsample_layers)
                add_time_downsample = bool(
                    i >= (len(block_out_channels) - 1 -
                          num_time_downsample_layers) and not is_final_block)
            elif temporal_compression_ratio == 8:
                add_spatial_downsample = bool(
                    i < num_spatial_downsample_layers)
                add_time_downsample = bool(i < num_time_downsample_layers)
            else:
                raise ValueError(
                    f"Unsupported time_compression_ratio: {temporal_compression_ratio}"
                )

            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
            downsample_stride_T = (2, ) if add_time_downsample else (1, )
            downsample_stride = tuple(downsample_stride_T +
                                      downsample_stride_HW)

            down_block = HunyuanVideoDownBlock3D(
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=bool(add_spatial_downsample
                                    or add_time_downsample),
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                downsample_stride=downsample_stride,
                downsample_padding=0,
            )

            self.down_blocks.append(down_block)

        self.mid_block = HunyuanVideoMidBlock3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            add_attention=mid_block_add_attention,
        )

        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1],
                                          num_groups=norm_num_groups,
                                          eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[-1],
                                                 conv_out_channels,
                                                 kernel_size=3)

        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for down_block in self.down_blocks:
                hidden_states = self._gradient_checkpointing_func(
                    down_block, hidden_states)

            hidden_states = self._gradient_checkpointing_func(
                self.mid_block, hidden_states)
        else:
            for down_block in self.down_blocks:
                hidden_states = down_block(hidden_states)

            assert self.mid_block is not None
            hidden_states = self.mid_block(hidden_states)

        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
class HunyuanVideoDecoder3D(nn.Module):
    r"""
    Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = (
            "HunyuanVideoUpBlock3D",
            "HunyuanVideoUpBlock3D",
            "HunyuanVideoUpBlock3D",
            "HunyuanVideoUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = HunyuanVideoCausalConv3d(in_channels,
                                                block_out_channels[-1],
                                                kernel_size=3,
                                                stride=1)
        self.up_blocks = nn.ModuleList([])

        # mid
        self.mid_block = HunyuanVideoMidBlock3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            add_attention=mid_block_add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            if up_block_type != "HunyuanVideoUpBlock3D":
                raise ValueError(f"Unsupported up_block_type: {up_block_type}")

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_upsample_layers = int(
                np.log2(spatial_compression_ratio))
            num_time_upsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                add_spatial_upsample = bool(i < num_spatial_upsample_layers)
                add_time_upsample = bool(
                    i >= len(block_out_channels) - 1 - num_time_upsample_layers
                    and not is_final_block)
            else:
                raise ValueError(
                    f"Unsupported time_compression_ratio: {time_compression_ratio}"
                )

            upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1,
                                                                            1)
            upsample_scale_factor_T = (2, ) if add_time_upsample else (1, )
            upsample_scale_factor = tuple(upsample_scale_factor_T +
                                          upsample_scale_factor_HW)

            up_block = HunyuanVideoUpBlock3D(
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=bool(add_spatial_upsample or add_time_upsample),
                upsample_scale_factor=upsample_scale_factor,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
            )

            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0],
                                          num_groups=norm_num_groups,
                                          eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[0],
                                                 out_channels,
                                                 kernel_size=3)

        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states = self._gradient_checkpointing_func(
                self.mid_block, hidden_states)

            for up_block in self.up_blocks:
                hidden_states = self._gradient_checkpointing_func(
                    up_block, hidden_states)
        else:
            hidden_states = self.mid_block(hidden_states)

            for up_block in self.up_blocks:
                hidden_states = up_block(hidden_states)

        # post-process
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
class AutoencoderKLHunyuanVideo(nn.Module, ParallelTiledVAE):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
    Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        config: HunyuanVAEConfig,
    ) -> None:
        nn.Module.__init__(self)
        ParallelTiledVAE.__init__(self, config)
        # TODO(will): only pass in config. We do this by manually defining a
        # config for hunyuan vae
        self.block_out_channels = config.block_out_channels

        if config.load_encoder:
            self.encoder = HunyuanVideoEncoder3D(
                in_channels=config.in_channels,
                out_channels=config.latent_channels,
                down_block_types=config.down_block_types,
                block_out_channels=config.block_out_channels,
                layers_per_block=config.layers_per_block,
                norm_num_groups=config.norm_num_groups,
                act_fn=config.act_fn,
                double_z=True,
                mid_block_add_attention=config.mid_block_add_attention,
                temporal_compression_ratio=config.temporal_compression_ratio,
                spatial_compression_ratio=config.spatial_compression_ratio,
            )
            self.quant_conv = nn.Conv3d(2 * config.latent_channels,
                                        2 * config.latent_channels,
                                        kernel_size=1)

        if config.load_decoder:
            self.decoder = HunyuanVideoDecoder3D(
                in_channels=config.latent_channels,
                out_channels=config.out_channels,
                up_block_types=config.up_block_types,
                block_out_channels=config.block_out_channels,
                layers_per_block=config.layers_per_block,
                norm_num_groups=config.norm_num_groups,
                act_fn=config.act_fn,
                time_compression_ratio=config.temporal_compression_ratio,
                spatial_compression_ratio=config.spatial_compression_ratio,
                mid_block_add_attention=config.mid_block_add_attention,
            )
            self.post_quant_conv = nn.Conv3d(config.latent_channels,
                                             config.latent_channels,
                                             kernel_size=1)

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        x = self.encoder(x)
        enc = self.quant_conv(x)
        return enc

    def _decode(self, z: torch.Tensor) -> torch.Tensor:
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec
FastVideo-main/fastvideo/v1/models/vaes/wanvae.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextvars
from contextlib import contextmanager
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from fastvideo.v1.configs.models.vaes import WanVAEConfig
from fastvideo.v1.layers.activation import get_act_fn
from fastvideo.v1.models.vaes.common import (DiagonalGaussianDistribution,
                                             ParallelTiledVAE)

CACHE_T = 2

is_first_frame = contextvars.ContextVar("is_first_frame", default=False)
feat_cache = contextvars.ContextVar("feat_cache", default=None)
feat_idx = contextvars.ContextVar("feat_idx", default=0)
@contextmanager
def forward_context(first_frame_arg=False,
                    feat_cache_arg=None,
                    feat_idx_arg=None):
    is_first_frame_token = is_first_frame.set(first_frame_arg)
    feat_cache_token = feat_cache.set(feat_cache_arg)
    feat_idx_token = feat_idx.set(feat_idx_arg)
    try:
        yield
    finally:
        is_first_frame.reset(is_first_frame_token)
        feat_cache.reset(feat_cache_token)
        feat_idx.reset(feat_idx_token)
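

# Illustrative sketch (not part of the original file): how the context manager
# above scopes the cache state. Values set inside the `with` block are visible
# to every layer called there and are restored when the block exits.
def _demo_forward_context():
    assert feat_cache.get() is None
    with forward_context(feat_cache_arg=[None] * 3, feat_idx_arg=0):
        assert feat_cache.get() == [None, None, None]
        assert feat_idx.get() == 0
    assert feat_cache.get() is None  # reset on exit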
class WanCausalConv3d(nn.Conv3d):
    r"""
    A custom 3D causal convolution layer with feature caching support.

    This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
    caching for efficient inference.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.padding: Tuple[int, int, int]

        # Set up causal padding
        self._padding: Tuple[int, ...] = (self.padding[2], self.padding[2],
                                          self.padding[1], self.padding[1],
                                          2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)
        return super().forward(x)
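

# Illustrative sketch (not part of the original file): how cache_x replaces
# the front temporal padding above. With CACHE_T == 2 cached frames and
# 2 * padding[0] == 2, the front pad drops to zero and the cached frames of
# the previous chunk supply the causal context instead.
def _demo_cached_causal_conv():
    conv = WanCausalConv3d(4, 4, 3, padding=1)
    chunk = torch.randn(1, 4, 4, 8, 8)
    cache = torch.randn(1, 4, CACHE_T, 8, 8)  # tail of the previous chunk
    out = conv(chunk, cache_x=cache)
    assert out.shape == (1, 4, 4 + CACHE_T - 2, 8, 8)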
class WanRMS_norm(nn.Module):
    r"""
    A custom RMS normalization layer.

    Args:
        dim (int): The number of dimensions to normalize over.
        channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
            Default is True.
        images (bool, optional): Whether the input represents image data. Default is True.
        bias (bool, optional): Whether to include a learnable bias term. Default is False.
    """

    def __init__(self,
                 dim: int,
                 channel_first: bool = True,
                 images: bool = True,
                 bias: bool = False) -> None:
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim, )

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return F.normalize(
            x, dim=(1 if self.channel_first else
                    -1)) * self.scale * self.gamma + self.bias
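

# Illustrative sketch (not part of the original file): the normalization above
# is F.normalize (unit L2 norm over channels) rescaled by sqrt(dim), which is
# the usual RMSNorm identity: x / rms(x) == l2_normalize(x) * sqrt(dim).
def _demo_rms_identity():
    x = torch.randn(2, 6, 4, 4)  # channel-first layout, dim == 6
    rms = x.pow(2).mean(dim=1, keepdim=True).sqrt()
    manual = x / rms
    via_normalize = F.normalize(x, dim=1) * 6**0.5
    assert torch.allclose(manual, via_normalize, atol=1e-5)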
class WanUpsample(nn.Upsample):
    r"""
    Perform upsampling while ensuring the output tensor has the same data type as the input.

    Args:
        x (torch.Tensor): Input tensor to be upsampled.

    Returns:
        torch.Tensor: Upsampled tensor with the same data type as the input.
    """

    def forward(self, x):
        return super().forward(x.float()).type_as(x)
class WanResample(nn.Module):
    r"""
    A custom resampling module for 2D and 3D data.

    Args:
        dim (int): The number of input/output channels.
        mode (str): The resampling mode. Must be one of:
            - 'none': No resampling (identity operation).
            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
            - 'downsample2d': 2D downsampling with zero-padding and convolution.
            - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
    """

    def __init__(self, dim: int, mode: str) -> None:
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == "upsample3d":
            self.resample = nn.Sequential(
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = WanCausalConv3d(dim,
                                             dim * 2, (3, 1, 1),
                                             padding=(1, 0, 0))
        elif mode == "downsample2d":
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = WanCausalConv3d(dim,
                                             dim, (3, 1, 1),
                                             stride=(2, 1, 1),
                                             padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

    def forward(self, x):
        b, c, t, h, w = x.size()
        first_frame = is_first_frame.get()
        if first_frame:
            assert t == 1
        _feat_cache = feat_cache.get()
        _feat_idx = feat_idx.get()
        if self.mode == "upsample3d":
            if _feat_cache is not None:
                idx = _feat_idx
                if _feat_cache[idx] is None:
                    _feat_cache[idx] = "Rep"
                    _feat_idx += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and _feat_cache[
                            idx] is not None and _feat_cache[idx] != "Rep":
                        # cache last frame of last two chunk
                        cache_x = torch.cat([
                            _feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device), cache_x
                        ],
                                            dim=2)
                    if cache_x.shape[2] < 2 and _feat_cache[
                            idx] is not None and _feat_cache[idx] == "Rep":
                        cache_x = torch.cat([
                            torch.zeros_like(cache_x).to(cache_x.device),
                            cache_x
                        ],
                                            dim=2)
                    if _feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, _feat_cache[idx])
                    _feat_cache[idx] = cache_x
                    _feat_idx += 1

                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                    3)
                    x = x.reshape(b, c, t * 2, h, w)
                feat_cache.set(_feat_cache)
                feat_idx.set(_feat_idx)
            elif not first_frame and hasattr(self, "time_conv"):
                x = self.time_conv(x)
                x = x.reshape(b, 2, c, t, h, w)
                x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                x = x.reshape(b, c, t * 2, h, w)

        t = x.shape[2]
        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        x = self.resample(x)
        x = x.view(b, t, x.size(1), x.size(2),
                   x.size(3)).permute(0, 2, 1, 3, 4)

        _feat_cache = feat_cache.get()
        _feat_idx = feat_idx.get()
        if self.mode == "downsample3d":
            if _feat_cache is not None:
                idx = _feat_idx
                if _feat_cache[idx] is None:
                    _feat_cache[idx] = x.clone()
                    _feat_idx += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    x = self.time_conv(
                        torch.cat([_feat_cache[idx][:, :, -1:, :, :], x], 2))
                    _feat_cache[idx] = cache_x
                    _feat_idx += 1
                feat_cache.set(_feat_cache)
                feat_idx.set(_feat_idx)
            elif not first_frame and hasattr(self, "time_conv"):
                x = self.time_conv(x)
        return x
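

# Illustrative sketch (not part of the original file): the reshape/stack trick
# used above for temporal upsampling. time_conv doubles the channels; viewing
# them as two per-frame groups and stacking on a new time axis interleaves the
# frame pairs, turning T frames into 2*T.
def _demo_temporal_interleave():
    b, c, t, h, w = 1, 2, 3, 1, 1
    doubled = torch.arange(b * 2 * c * t * h * w,
                           dtype=torch.float32).reshape(b, 2 * c, t, h, w)
    x = doubled.reshape(b, 2, c, t, h, w)
    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
    x = x.reshape(b, c, t * 2, h, w)
    assert x.shape == (1, 2, 6, 1, 1)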
class WanResidualBlock(nn.Module):
    r"""
    A custom residual block module.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
        non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.0,
        non_linearity: str = "silu",
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.nonlinearity = get_act_fn(non_linearity)

        # layers
        self.norm1 = WanRMS_norm(in_dim, images=False)
        self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
        self.norm2 = WanRMS_norm(out_dim, images=False)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
        self.conv_shortcut = WanCausalConv3d(
            in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    def forward(self, x):
        # Apply shortcut connection
        h = self.conv_shortcut(x)

        # First normalization and activation
        x = self.norm1(x)
        x = self.nonlinearity(x)

        _feat_cache = feat_cache.get()
        _feat_idx = feat_idx.get()
        if _feat_cache is not None:
            idx = _feat_idx
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:
                cache_x = torch.cat([
                    _feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv1(x, _feat_cache[idx])
            _feat_cache[idx] = cache_x
            _feat_idx += 1
            feat_cache.set(_feat_cache)
            feat_idx.set(_feat_idx)
        else:
            x = self.conv1(x)

        # Second normalization and activation
        x = self.norm2(x)
        x = self.nonlinearity(x)

        # Dropout
        x = self.dropout(x)

        _feat_cache = feat_cache.get()
        _feat_idx = feat_idx.get()
        if _feat_cache is not None:
            idx = _feat_idx
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:
                cache_x = torch.cat([
                    _feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv2(x, _feat_cache[idx])
            _feat_cache[idx] = cache_x
            _feat_idx += 1
            feat_cache.set(_feat_cache)
            feat_idx.set(_feat_idx)
        else:
            x = self.conv2(x)

        # Add residual connection
        return x + h
class WanAttentionBlock(nn.Module):
    r"""
    Causal self-attention with a single head.

    Args:
        dim (int): The number of channels in the input tensor.
    """

    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim

        # layers
        self.norm = WanRMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        identity = x
        batch_size, channels, time, height, width = x.size()

        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels,
                                             height, width)
        x = self.norm(x)

        # compute query, key, value
        qkv = self.to_qkv(x)
        qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
        qkv = qkv.permute(0, 1, 3, 2).contiguous()
        q, k, v = qkv.chunk(3, dim=-1)

        # apply attention
        x = F.scaled_dot_product_attention(q, k, v)

        x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels,
                                                  height, width)

        # output projection
        x = self.proj(x)

        # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
        x = x.view(batch_size, time, channels, height, width)
        x = x.permute(0, 2, 1, 3, 4)

        return x + identity
class WanMidBlock(nn.Module):
    """
    Middle block for WanVAE encoder and decoder.

    Args:
        dim (int): Number of input/output channels.
        dropout (float): Dropout rate.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(self,
                 dim: int,
                 dropout: float = 0.0,
                 non_linearity: str = "silu",
                 num_layers: int = 1):
        super().__init__()
        self.dim = dim

        # Create the components
        resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
        attentions = []
        for _ in range(num_layers):
            attentions.append(WanAttentionBlock(dim))
            resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(self, x):
        # First residual block
        x = self.resnets[0](x)

        # Process through attention and residual blocks
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                x = attn(x)
            x = resnet(x)

        return x
class WanEncoder3d(nn.Module):
    r"""
    A 3D encoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_downsample (list of bool): Whether to downsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        attn_scales=(),
        temperal_downsample=(True, True, False),
        dropout=0.0,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        dim_mult = list(dim_mult)
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = list(attn_scales)
        self.temperal_downsample = list(temperal_downsample)
        self.nonlinearity = get_act_fn(non_linearity)

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        self.down_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                self.down_blocks.append(
                    WanResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    self.down_blocks.append(WanAttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = "downsample3d" if temperal_downsample[
                    i] else "downsample2d"
                self.down_blocks.append(WanResample(out_dim, mode=mode))
                scale /= 2.0

        # middle blocks
        self.mid_block = WanMidBlock(out_dim, dropout, non_linearity,
                                     num_layers=1)

        # output blocks
        self.norm_out = WanRMS_norm(out_dim, images=False)
        self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x):
        _feat_cache = feat_cache.get()
        _feat_idx = feat_idx.get()
        if _feat_cache is not None:
            idx = _feat_idx
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    _feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv_in(x, _feat_cache[idx])
            _feat_cache[idx] = cache_x
            _feat_idx += 1
            feat_cache.set(_feat_cache)
            feat_idx.set(_feat_idx)
        else:
            x = self.conv_in(x)

        ## downsamples
        for layer in self.down_blocks:
            x = layer(x)

        ## middle
        x = self.mid_block(x)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)

        _feat_cache = feat_cache.get()
        _feat_idx = feat_idx.get()
        if _feat_cache is not None:
            idx = _feat_idx
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    _feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv_out(x, _feat_cache[idx])
            _feat_cache[idx] = cache_x
            _feat_idx += 1
            feat_cache.set(_feat_cache)
            feat_idx.set(_feat_idx)
        else:
            x = self.conv_out(x)
        return x
class WanUpBlock(nn.Module):
    """
    A block that handles upsampling for the WanVAE decoder.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks
        dropout (float): Dropout rate
        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        upsample_mode: Optional[str] = None,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Create layers list
        resnets = []
        # Add residual blocks and attention if needed
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
            resnets.append(
                WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
            current_dim = out_dim

        self.resnets = nn.ModuleList(resnets)

        # Add upsampling layer if needed
        self.upsamplers = None
        if upsample_mode is not None:
            self.upsamplers = nn.ModuleList(
                [WanResample(out_dim, mode=upsample_mode)])

        self.gradient_checkpointing = False

    def forward(self, x):
        """
        Forward pass through the upsampling block.

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Output tensor
        """
        for resnet in self.resnets:
            x = resnet(x)

        if self.upsamplers is not None:
            x = self.upsamplers[0](x)
        return x
class
WanDecoder3d
(
nn
.
Module
):
r
"""
A 3D decoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_upsample (list of bool): Whether to upsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def
__init__
(
self
,
dim
=
128
,
z_dim
=
4
,
dim_mult
=
(
1
,
2
,
4
,
4
),
num_res_blocks
=
2
,
attn_scales
=
(),
temperal_upsample
=
(
False
,
True
,
True
),
dropout
=
0.0
,
non_linearity
:
str
=
"silu"
,
):
super
().
__init__
()
self
.
dim
=
dim
self
.
z_dim
=
z_dim
dim_mult
=
list
(
dim_mult
)
self
.
dim_mult
=
dim_mult
self
.
num_res_blocks
=
num_res_blocks
self
.
attn_scales
=
list
(
attn_scales
)
self
.
temperal_upsample
=
list
(
temperal_upsample
)
self
.
nonlinearity
=
get_act_fn
(
non_linearity
)
# dimensions
dims
=
[
dim
*
u
for
u
in
[
dim_mult
[
-
1
]]
+
dim_mult
[::
-
1
]]
scale
=
1.0
/
2
**
(
len
(
dim_mult
)
-
2
)
# init block
self
.
conv_in
=
WanCausalConv3d
(
z_dim
,
dims
[
0
],
3
,
padding
=
1
)
# middle blocks
self
.
mid_block
=
WanMidBlock
(
dims
[
0
],
dropout
,
non_linearity
,
num_layers
=
1
)
# upsample blocks
self
.
up_blocks
=
nn
.
ModuleList
([])
for
i
,
(
in_dim
,
out_dim
)
in
enumerate
(
zip
(
dims
[:
-
1
],
dims
[
1
:])):
# residual (+attention) blocks
if
i
>
0
:
in_dim
=
in_dim
//
2
# Determine if we need upsampling
upsample_mode
=
None
if
i
!=
len
(
dim_mult
)
-
1
:
upsample_mode
=
"upsample3d"
if
temperal_upsample
[
i
]
else
"upsample2d"
# Create and add the upsampling block
up_block
=
WanUpBlock
(
in_dim
=
in_dim
,
out_dim
=
out_dim
,
num_res_blocks
=
num_res_blocks
,
dropout
=
dropout
,
upsample_mode
=
upsample_mode
,
non_linearity
=
non_linearity
,
)
self
.
up_blocks
.
append
(
up_block
)
# Update scale for next iteration
if
upsample_mode
is
not
None
:
scale
*=
2.0
# output blocks
self
.
norm_out
=
WanRMS_norm
(
out_dim
,
images
=
False
)
self
.
conv_out
=
WanCausalConv3d
(
out_dim
,
3
,
3
,
padding
=
1
)
self
.
gradient_checkpointing
=
False
    def forward(self, x):
        ## conv1
        _feat_cache = feat_cache.get()
        _feat_idx = feat_idx.get()
        if _feat_cache is not None:
            idx = _feat_idx
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    _feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ], dim=2)
            x = self.conv_in(x, _feat_cache[idx])
            _feat_cache[idx] = cache_x
            _feat_idx += 1
            feat_cache.set(_feat_cache)
            feat_idx.set(_feat_idx)
        else:
            x = self.conv_in(x)

        ## middle
        x = self.mid_block(x)

        ## upsamples
        for up_block in self.up_blocks:
            x = up_block(x)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)

        _feat_cache = feat_cache.get()
        _feat_idx = feat_idx.get()
        if _feat_cache is not None:
            idx = _feat_idx
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    _feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ], dim=2)
            x = self.conv_out(x, _feat_cache[idx])
            _feat_cache[idx] = cache_x
            _feat_idx += 1
            feat_cache.set(_feat_cache)
            feat_idx.set(_feat_idx)
        else:
            x = self.conv_out(x)
        return x
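To make the channel bookkeeping in `__init__` concrete: with the defaults `dim=128` and `dim_mult=(1, 2, 4, 4)`, `[dim_mult[-1]] + dim_mult[::-1]` is `[4, 4, 4, 2, 1]`, so `dims = [512, 512, 512, 256, 128]` and `zip(dims[:-1], dims[1:])` yields the stage pairs (512, 512), (512, 512), (512, 256), (256, 128). For every stage after the first, `in_dim` is halved because the preceding stage's upsampler halves the channel count on the way up.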
class AutoencoderKLWan(nn.Module, ParallelTiledVAE):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding
    latent representations into videos. Introduced in [Wan 2.1].
    """

    _supports_gradient_checkpointing = False

    def __init__(
        self,
        config: WanVAEConfig,
    ) -> None:
        nn.Module.__init__(self)
        ParallelTiledVAE.__init__(self, config)
        self.z_dim = config.z_dim
        self.temperal_downsample = list(config.temperal_downsample)
        self.temperal_upsample = list(config.temperal_downsample)[::-1]

        self.latents_mean = list(config.latents_mean)
        self.latents_std = list(config.latents_std)
        self.shift_factor = config.shift_factor

        if config.load_encoder:
            self.encoder = WanEncoder3d(config.base_dim, self.z_dim * 2,
                                        config.dim_mult,
                                        config.num_res_blocks,
                                        config.attn_scales,
                                        self.temperal_downsample,
                                        config.dropout)
        self.quant_conv = WanCausalConv3d(self.z_dim * 2, self.z_dim * 2, 1)
        self.post_quant_conv = WanCausalConv3d(self.z_dim, self.z_dim, 1)
        if config.load_decoder:
            self.decoder = WanDecoder3d(config.base_dim, self.z_dim,
                                        config.dim_mult,
                                        config.num_res_blocks,
                                        config.attn_scales,
                                        self.temperal_upsample,
                                        config.dropout)

        self.use_feature_cache = config.use_feature_cache

    def clear_cache(self) -> None:

        def _count_conv3d(model) -> int:
            count = 0
            for m in model.modules():
                if isinstance(m, WanCausalConv3d):
                    count += 1
            return count

        if self.config.load_decoder:
            self._conv_num = _count_conv3d(self.decoder)
            self._conv_idx = 0
            self._feat_map = [None] * self._conv_num
        # cache encode
        if self.config.load_encoder:
            self._enc_conv_num = _count_conv3d(self.encoder)
            self._enc_conv_idx = 0
            self._enc_feat_map = [None] * self._enc_conv_num

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_feature_cache:
            self.clear_cache()
            with forward_context(feat_cache_arg=self._enc_feat_map,
                                 feat_idx_arg=self._enc_conv_idx):
                t = x.shape[2]
                iter_ = 1 + (t - 1) // 4
                for i in range(iter_):
                    feat_idx.set(0)
                    if i == 0:
                        out = self.encoder(x[:, :, :1, :, :])
                    else:
                        out_ = self.encoder(
                            x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :])
                        out = torch.cat([out, out_], 2)
                enc = self.quant_conv(out)
                mu, logvar = (enc[:, :self.z_dim, :, :, :],
                              enc[:, self.z_dim:, :, :, :])
                enc = torch.cat([mu, logvar], dim=1)
                enc = DiagonalGaussianDistribution(enc)
            self.clear_cache()
        else:
            for block in self.encoder.down_blocks:
                if isinstance(block,
                              WanResample) and block.mode == "downsample3d":
                    _padding = list(block.time_conv._padding)
                    _padding[4] = 2
                    block.time_conv._padding = tuple(_padding)
            enc = ParallelTiledVAE.encode(self, x)
        return enc

    def _encode(self, x: torch.Tensor, first_frame=False) -> torch.Tensor:
        with forward_context(first_frame_arg=first_frame):
            out = self.encoder(x)
            enc = self.quant_conv(out)
            mu, logvar = (enc[:, :self.z_dim, :, :, :],
                          enc[:, self.z_dim:, :, :, :])
            enc = torch.cat([mu, logvar], dim=1)
        return enc

    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        first_frame = x[:, :, 0, :, :].unsqueeze(2)
        first_frame = self._encode(first_frame, first_frame=True)
        enc = ParallelTiledVAE.tiled_encode(self, x)
        enc = enc[:, :, 1:]
        enc = torch.cat([first_frame, enc], dim=2)
        return enc

    def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        first_frame = x[:, :, 0, :, :].unsqueeze(2)
        first_frame = self._encode(first_frame, first_frame=True)
        enc = ParallelTiledVAE.spatial_tiled_encode(self, x)
        enc = enc[:, :, 1:]
        enc = torch.cat([first_frame, enc], dim=2)
        return enc

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        if self.use_feature_cache:
            self.clear_cache()
            iter_ = z.shape[2]
            x = self.post_quant_conv(z)
            with forward_context(feat_cache_arg=self._feat_map,
                                 feat_idx_arg=self._conv_idx):
                for i in range(iter_):
                    feat_idx.set(0)
                    if i == 0:
                        out = self.decoder(x[:, :, i:i + 1, :, :])
                    else:
                        out_ = self.decoder(x[:, :, i:i + 1, :, :])
                        out = torch.cat([out, out_], 2)
            out = torch.clamp(out, min=-1.0, max=1.0)
            self.clear_cache()
        else:
            out = ParallelTiledVAE.decode(self, z)
        return out

    def _decode(self, z: torch.Tensor, first_frame=False) -> torch.Tensor:
        x = self.post_quant_conv(z)
        with forward_context(first_frame_arg=first_frame):
            out = self.decoder(x)
        out = torch.clamp(out, min=-1.0, max=1.0)
        return out

    def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
        self.blend_num_frames *= 2
        dec = ParallelTiledVAE.tiled_decode(self, z)
        start_frame_idx = self.temporal_compression_ratio - 1
        dec = dec[:, :, start_frame_idx:]
        return dec

    def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
        dec = ParallelTiledVAE.spatial_tiled_decode(self, z)
        start_frame_idx = self.temporal_compression_ratio - 1
        dec = dec[:, :, start_frame_idx:]
        return dec

    def parallel_tiled_decode(self,
                              z: torch.FloatTensor) -> torch.FloatTensor:
        self.blend_num_frames *= 2
        dec = ParallelTiledVAE.parallel_tiled_decode(self, z)
        start_frame_idx = self.temporal_compression_ratio - 1
        dec = dec[:, :, start_frame_idx:]
        return dec
    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior distribution.
            generator (`torch.Generator`, *optional*):
                A torch generator used to make sampling deterministic.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec
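A hedged end-to-end sketch of driving this VAE with the feature cache enabled; the `WanVAEConfig` construction and frame count are illustrative assumptions, not shipped defaults. With 17 input frames, `iter_ = 1 + (17 - 1) // 4 = 5`, so the encoder sees the first frame alone followed by four 4-frame chunks:

    import torch

    config = WanVAEConfig(load_encoder=True, load_decoder=True)  # assumed fields
    vae = AutoencoderKLWan(config)
    video = torch.randn(1, 3, 17, 256, 256)  # frame count of the form 1 + 4k
    posterior = vae.encode(video)            # DiagonalGaussianDistribution
    latents = posterior.mode()
    recon = vae.decode(latents)              # decoder output clamped to [-1, 1]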
FastVideo-main/fastvideo/v1/models/vision_utils.py
0 → 100644
View file @
c07946d8
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import PIL.ImageOps
import requests
import torch
from packaging import version

if version.parse(
        version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
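This table keeps the `resize` helper below Pillow-version-agnostic: for example, `PIL_INTERPOLATION["lanczos"]` resolves to `PIL.Image.Resampling.LANCZOS` on Pillow >= 9.1.0 and to the older `PIL.Image.LANCZOS` constant on earlier releases.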
def pil_to_numpy(
        images: Union[List[PIL.Image.Image],
                      PIL.Image.Image]) -> np.ndarray:
    r"""
    Convert a PIL image or a list of PIL images to NumPy arrays.

    Args:
        images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
            The PIL image or list of images to convert to NumPy format.

    Returns:
        `np.ndarray`:
            A NumPy array representation of the images.
    """
    if not isinstance(images, list):
        images = [images]
    images = [np.array(image).astype(np.float32) / 255.0 for image in images]
    images_arr: np.ndarray = np.stack(images, axis=0)
    return images_arr
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
    r"""
    Convert a NumPy image to a PyTorch tensor.

    Args:
        images (`np.ndarray`):
            The NumPy image array to convert to PyTorch format.

    Returns:
        `torch.Tensor`:
            A PyTorch tensor representation of the images.
    """
    if images.ndim == 3:
        images = images[..., None]
    images = torch.from_numpy(images.transpose(0, 3, 1, 2))
    return images
def normalize(
    images: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    r"""
    Normalize an image array to [-1, 1].

    Args:
        images (`np.ndarray` or `torch.Tensor`):
            The image array to normalize.

    Returns:
        `np.ndarray` or `torch.Tensor`:
            The normalized image array.
    """
    return 2.0 * images - 1.0
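These three helpers compose into the usual preprocessing chain. A small sketch (the image size is arbitrary):

    import PIL.Image

    img = PIL.Image.new("RGB", (64, 64))
    arr = pil_to_numpy(img)  # (1, 64, 64, 3), float32 in [0, 1]
    pt = numpy_to_pt(arr)    # (1, 3, 64, 64) torch.Tensor
    pt = normalize(pt)       # values mapped from [0, 1] to [-1, 1]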
def load_image(
    image: Union[str, PIL.Image.Image],
    convert_method: Optional[Callable[[PIL.Image.Image],
                                      PIL.Image.Image]] = None
) -> PIL.Image.Image:
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format.
        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
            A conversion method to apply to the image after loading it. When
            set to `None` the image will be converted to "RGB".

    Returns:
        `PIL.Image.Image`:
            A PIL Image.
    """
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            image = PIL.Image.open(requests.get(image, stream=True).raw)
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or "
                f"`https://`, and {image} is not a valid path.")
    elif isinstance(image, PIL.Image.Image):
        image = image
    else:
        raise ValueError(
            "Incorrect format used for the image. Should be a URL linking to "
            "an image, a local path, or a PIL image.")

    image = PIL.ImageOps.exif_transpose(image)

    if convert_method is not None:
        image = convert_method(image)
    else:
        image = image.convert("RGB")

    return image
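Usage sketch; the file path and the grayscale conversion are hypothetical:

    img = load_image("example.png")  # EXIF-transposed and converted to RGB
    gray = load_image("example.png", convert_method=lambda im: im.convert("L"))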
def get_default_height_width(
    image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
    vae_scale_factor: int,
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> Tuple[int, int]:
    r"""
    Returns the height and width of the image, rounded down to the nearest
    integer multiple of `vae_scale_factor`.

    Args:
        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
            The image input, which can be a PIL image, NumPy array, or PyTorch
            tensor. If it is a NumPy array, it should have shape
            `[batch, height, width]` or `[batch, height, width, channels]`. If
            it is a PyTorch tensor, it should have shape
            `[batch, channels, height, width]`.
        height (`Optional[int]`, *optional*, defaults to `None`):
            The height of the preprocessed image. If `None`, the height of the
            `image` input will be used.
        width (`Optional[int]`, *optional*, defaults to `None`):
            The width of the preprocessed image. If `None`, the width of the
            `image` input will be used.

    Returns:
        `Tuple[int, int]`:
            A tuple containing the height and width, both rounded down to the
            nearest integer multiple of `vae_scale_factor`.
    """
    if height is None:
        if isinstance(image, PIL.Image.Image):
            height = image.height
        elif isinstance(image, torch.Tensor):
            height = image.shape[2]
        else:
            height = image.shape[1]

    if width is None:
        if isinstance(image, PIL.Image.Image):
            width = image.width
        elif isinstance(image, torch.Tensor):
            width = image.shape[3]
        else:
            width = image.shape[2]

    # resize to integer multiple of vae_scale_factor
    width, height = (x - x % vae_scale_factor for x in (width, height))
    return height, width
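Worked example: with `vae_scale_factor=8`, a 1023x767 PIL image gives `width = 1023 - 1023 % 8 = 1016` and `height = 767 - 767 % 8 = 760`, so the function returns `(760, 1016)`, height first.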
def resize(
    image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
    height: int,
    width: int,
    resize_mode: str = "default",  # "default", "fill", "crop"
    resample: str = "lanczos",
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
    """
    Resize image.

    Args:
        image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
            The image input, can be a PIL image, numpy array or pytorch tensor.
        height (`int`):
            The height to resize to.
        width (`int`):
            The width to resize to.
        resize_mode (`str`, *optional*, defaults to `default`):
            The resize mode to use. If `default`, the image is resized to the
            specified width and height, which may not maintain the original
            aspect ratio. The `fill` and `crop` modes (aspect-preserving, with
            padding or cropping, and only meaningful for PIL image input) are
            listed for compatibility but are not implemented here and raise a
            `ValueError`.

    Returns:
        `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
            The resized image.
    """
    if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
        raise ValueError(
            f"Only PIL image input is supported for resize_mode {resize_mode}")
    assert isinstance(image, PIL.Image.Image)
    if resize_mode == "default":
        image = image.resize((width, height),
                             resample=PIL_INTERPOLATION[resample])
    else:
        raise ValueError(f"resize_mode {resize_mode} is not supported")
    return image
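Putting the helpers together, a typical preprocessing call might look like this (the input path is hypothetical):

    img = load_image("example.png")
    h, w = get_default_height_width(img, vae_scale_factor=8)
    img = resize(img, h, w, resample="lanczos")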
FastVideo-main/fastvideo/v1/pipelines/README.md
0 → 100644
View file @
c07946d8
# Adding a New Custom Pipeline
Please see
FastVideo-main/fastvideo/v1/pipelines/__init__.py
0 → 100644
View file @
c07946d8
# SPDX-License-Identifier: Apache-2.0
"""
Diffusion pipelines for fastvideo.v1.

This package contains diffusion pipelines for generating videos and images.
"""
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.logger import init_logger
from fastvideo.v1.pipelines.composed_pipeline_base import ComposedPipelineBase
from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
from fastvideo.v1.pipelines.pipeline_registry import PipelineRegistry
from fastvideo.v1.utils import (maybe_download_model,
                                verify_model_config_and_directory)

logger = init_logger(__name__)


def build_pipeline(fastvideo_args: FastVideoArgs) -> ComposedPipelineBase:
    """
    Only works with valid hf diffusers configs (model_index.json).
    We want to build a pipeline based on the inference args' model_path:
    1. download the model from the hub if it's not already downloaded
    2. verify the model config and directory
    3. based on the config, determine the pipeline class
    """
    # Get pipeline type
    model_path = fastvideo_args.model_path
    model_path = maybe_download_model(model_path)
    # fastvideo_args.downloaded_model_path = model_path
    logger.info("Model path: %s", model_path)
    config = verify_model_config_and_directory(model_path)

    pipeline_architecture = config.get("_class_name")
    if pipeline_architecture is None:
        raise ValueError("Model config does not contain a _class_name "
                         "attribute. Only diffusers format is supported.")

    pipeline_cls, pipeline_architecture = PipelineRegistry.resolve_pipeline_cls(
        pipeline_architecture)

    # instantiate the pipeline
    pipeline = pipeline_cls(model_path, fastvideo_args, config)

    logger.info("Pipeline instantiated")
    # pipeline is now initialized and ready to use
    return pipeline


__all__ = [
    "build_pipeline",
    "list_available_pipelines",
    "ComposedPipelineBase",
    "PipelineRegistry",
    "ForwardBatch",
]
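A hedged usage sketch, assuming `FastVideoArgs` can be constructed with just a `model_path` (the model id is illustrative):

    args = FastVideoArgs(model_path="FastVideo/FastHunyuan-diffusers")
    pipeline = build_pipeline(args)  # class resolved from model_index.json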
FastVideo-main/fastvideo/v1/pipelines/composed_pipeline_base.py
0 → 100644
View file @
c07946d8
# SPDX-License-Identifier: Apache-2.0
"""
Base class for composed pipelines.

This module defines the base class for pipelines that are composed of
multiple stages.
"""
import os
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, List, Optional, cast

import torch

from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.logger import init_logger
from fastvideo.v1.models.loader.component_loader import PipelineComponentLoader
from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
from fastvideo.v1.pipelines.stages import PipelineStage
from fastvideo.v1.utils import (maybe_download_model,
                                verify_model_config_and_directory)

logger = init_logger(__name__)


class ComposedPipelineBase(ABC):
    """
    Base class for pipelines composed of multiple stages.

    This class provides the framework for creating pipelines by composing
    multiple stages together. Each stage is responsible for a specific part of
    the diffusion process, and the pipeline orchestrates the execution of
    these stages.
    """

    is_video_pipeline: bool = False  # To be overridden by video pipelines
    _required_config_modules: List[str] = []

    # TODO(will): args should support both inference args and training args
    def __init__(self,
                 model_path: str,
                 fastvideo_args: FastVideoArgs,
                 config: Optional[Dict[str, Any]] = None):
        """
        Initialize the pipeline. After __init__, the pipeline should be ready
        to use. The pipeline should be stateless and not hold any batch state.
        """
        self.model_path = model_path
        self._stages: List[PipelineStage] = []
        self._stage_name_mapping: Dict[str, PipelineStage] = {}

        if self._required_config_modules is None:
            raise NotImplementedError(
                "Subclass must set _required_config_modules")

        if config is None:
            # Load configuration
            logger.info("Loading pipeline configuration...")
            self.config = self._load_config(model_path)
        else:
            self.config = config

        # Load modules directly in initialization
        logger.info("Loading pipeline modules...")
        self.modules = self.load_modules(fastvideo_args)

        self.initialize_pipeline(fastvideo_args)

        logger.info("Creating pipeline stages...")
        self.create_pipeline_stages(fastvideo_args)

    def get_module(self, module_name: str) -> Any:
        return self.modules[module_name]

    def add_module(self, module_name: str, module: Any):
        self.modules[module_name] = module

    def _load_config(self, model_path: str) -> Dict[str, Any]:
        model_path = maybe_download_model(self.model_path)
        self.model_path = model_path
        # fastvideo_args.downloaded_model_path = model_path
        logger.info("Model path: %s", model_path)
        config = verify_model_config_and_directory(model_path)
        return cast(Dict[str, Any], config)

    @property
    def required_config_modules(self) -> List[str]:
        """
        List of modules that are required by the pipeline. The names should
        match the diffusers directory and model_index.json file. These modules
        will be loaded using the PipelineComponentLoader and made available in
        the modules dictionary. Access these modules using the get_module
        method.

        class ConcretePipeline(ComposedPipelineBase):
            _required_config_modules = [
                "vae", "text_encoder", "transformer", "scheduler", "tokenizer"
            ]

            @property
            def required_config_modules(self):
                return self._required_config_modules
        """
        return self._required_config_modules

    @property
    def stages(self) -> List[PipelineStage]:
        """
        List of stages in the pipeline.
        """
        return self._stages

    @abstractmethod
    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
        """
        Create the pipeline stages.
        """
        raise NotImplementedError

    def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
        """
        Initialize the pipeline.
        """
        return

    def load_modules(self, fastvideo_args: FastVideoArgs) -> Dict[str, Any]:
        """
        Load the modules from the config.
        """
        logger.info("Loading pipeline modules from config: %s", self.config)
        modules_config = deepcopy(self.config)

        # remove keys that are not pipeline modules
        modules_config.pop("_class_name")
        modules_config.pop("_diffusers_version")

        # some sanity checks
        assert len(
            modules_config
        ) > 1, "model_index.json must contain at least one pipeline module"
        required_modules = [
            "vae", "text_encoder", "transformer", "scheduler", "tokenizer"
        ]
        for module_name in required_modules:
            if module_name not in modules_config:
                raise ValueError(
                    f"model_index.json must contain a {module_name} module")
        logger.info("Diffusers config passed sanity checks")

        # all the component models used by the pipeline
        modules = {}
        for module_name, (transformers_or_diffusers,
                          architecture) in modules_config.items():
            component_model_path = os.path.join(self.model_path, module_name)
            module = PipelineComponentLoader.load_module(
                module_name=module_name,
                component_model_path=component_model_path,
                transformers_or_diffusers=transformers_or_diffusers,
                architecture=architecture,
                fastvideo_args=fastvideo_args,
            )
            logger.info("Loaded module %s from %s", module_name,
                        component_model_path)

            if module_name in modules:
                logger.warning("Overwriting module %s", module_name)
            modules[module_name] = module

        required_modules = self.required_config_modules
        # Check if all required modules were loaded
        for module_name in required_modules:
            if module_name not in modules or modules[module_name] is None:
                raise ValueError(
                    f"Required module {module_name} was not loaded properly")

        return modules

    def add_stage(self, stage_name: str, stage: PipelineStage):
        assert self.modules is not None, "No modules are registered"
        self._stages.append(stage)
        self._stage_name_mapping[stage_name] = stage
        setattr(self, stage_name, stage)

    # TODO(will): don't hardcode no_grad
    @torch.no_grad()
    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """
        Generate a video or image using the pipeline.

        Args:
            batch: The batch to generate from.
            fastvideo_args: The inference arguments.

        Returns:
            ForwardBatch: The batch with the generated video or image.
        """
        # Execute each stage
        logger.info("Running pipeline stages: %s",
                    self._stage_name_mapping.keys())
        logger.info("Batch: %s", batch)
        for stage in self.stages:
            batch = stage(batch, fastvideo_args)

        # Return the output
        return batch
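A hedged sketch of a concrete subclass; the stage classes and their constructor arguments are assumptions about `fastvideo.v1.pipelines.stages`, not verified signatures:

    class MyVideoPipeline(ComposedPipelineBase):
        _required_config_modules = [
            "vae", "text_encoder", "transformer", "scheduler", "tokenizer"
        ]

        def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
            # Hypothetical stages implementing the PipelineStage interface
            self.add_stage("input_validation_stage", InputValidationStage())
            self.add_stage(
                "denoising_stage",
                DenoisingStage(transformer=self.get_module("transformer"),
                               scheduler=self.get_module("scheduler")))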
FastVideo-main/fastvideo/v1/pipelines/hunyuan/__init__.py
0 → 100644
View file @
c07946d8