TS-MODELS-OPT / training / Video-Generation-Model / Commits / c07946d8

Commit c07946d8, authored Apr 09, 2026 by hepj

    dit & video
Showing 20 changed files with 4026 additions and 0 deletions (+4026, −0)
FastVideo-main/fastvideo/models/stepvideo/__version__.py                    +1    −0
FastVideo-main/fastvideo/models/stepvideo/config.py                         +174  −0
FastVideo-main/fastvideo/models/stepvideo/diffusion/scheduler.py            +220  −0
FastVideo-main/fastvideo/models/stepvideo/diffusion/video_pipeline.py       +325  −0
FastVideo-main/fastvideo/models/stepvideo/modules/attentions.py             +96   −0
FastVideo-main/fastvideo/models/stepvideo/modules/blocks.py                 +296  −0
FastVideo-main/fastvideo/models/stepvideo/modules/model.py                  +760  −0
FastVideo-main/fastvideo/models/stepvideo/modules/model.py-bak              +198  −0
FastVideo-main/fastvideo/models/stepvideo/modules/model.py-new              +760  −0
FastVideo-main/fastvideo/models/stepvideo/modules/normalization.py          +312  −0
FastVideo-main/fastvideo/models/stepvideo/modules/rope.py                   +90   −0
FastVideo-main/fastvideo/models/stepvideo/parallel.py                       +21   −0
FastVideo-main/fastvideo/models/stepvideo/text_encoder/__init__.py          +12   −0
FastVideo-main/fastvideo/models/stepvideo/text_encoder/clip.py              +36   −0
FastVideo-main/fastvideo/models/stepvideo/text_encoder/flashattention.py    +45   −0
FastVideo-main/fastvideo/models/stepvideo/text_encoder/stepllm.py           +291  −0
FastVideo-main/fastvideo/models/stepvideo/text_encoder/tokenizer.py         +209  −0
FastVideo-main/fastvideo/models/stepvideo/utils/__init__.py                 +3    −0
FastVideo-main/fastvideo/models/stepvideo/utils/quantization.py             +117  −0
FastVideo-main/fastvideo/models/stepvideo/utils/utils.py                    +60   −0
Too many changes to show: to preserve performance, only 270 of 270+ files are displayed.
FastVideo-main/fastvideo/models/stepvideo/__version__.py (new file, 0 → 100644)

__version__ = "0.1.0"
FastVideo-main/fastvideo/models/stepvideo/config.py (new file, 0 → 100644)

import argparse


def parse_args(namespace=None):
    parser = argparse.ArgumentParser(description="StepVideo inference script")

    parser = add_extra_models_args(parser)
    parser = add_denoise_schedule_args(parser)
    parser = add_inference_args(parser)
    parser = add_parallel_args(parser)

    args = parser.parse_args(namespace=namespace)
    return args


def add_extra_models_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Extra models args, including vae, text encoders and tokenizers")
    group.add_argument("--vae_url", type=str, default='127.0.0.1', help="vae url.")
    group.add_argument("--caption_url", type=str, default='127.0.0.1', help="caption url.")
    return parser


def add_denoise_schedule_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Denoise schedule args")

    # Flow Matching
    group.add_argument("--time_shift", type=float, default=7.0,
                       help="Shift factor for flow matching schedulers.")
    group.add_argument("--flow_reverse", action="store_true",
                       help="If reverse, learning/sampling from t=1 -> t=0.")
    group.add_argument("--flow_solver", type=str, default="euler", help="Solver for flow matching.")
    return parser


def add_inference_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Inference args")

    # ======================== Model loads ========================
    group.add_argument("--model_dir", type=str, default="./ckpts",
                       help="Root path of all the models, including t2v models and extra models.")
    group.add_argument("--model_resolution", type=str, default="540p", choices=["540p"],
                       help="Resolution of the t2v model to load.")
    group.add_argument("--use-cpu-offload", action="store_true",
                       help="Use CPU offload for the model load.")

    # ======================== Inference general setting ========================
    group.add_argument("--batch_size", type=int, default=1,
                       help="Batch size for inference and evaluation.")
    group.add_argument("--infer_steps", type=int, default=50,
                       help="Number of denoising steps for inference.")
    group.add_argument("--save_path", type=str, default="./results",
                       help="Path to save the generated samples.")
    group.add_argument("--name_suffix", type=str, default="",
                       help="Suffix for the names of saved samples.")
    group.add_argument("--num_videos", type=int, default=1,
                       help="Number of videos to generate for each prompt.")

    # ---sample size---
    group.add_argument("--num_frames", type=int, default=204,
                       help="How many frames to sample from a video.")
    group.add_argument("--height", type=int, default=544, help="The height of video sample")
    group.add_argument("--width", type=int, default=992, help="The width of video sample")

    # --- prompt ---
    group.add_argument("--prompt", type=str, default=None,
                       help="Prompt for sampling during evaluation.")
    group.add_argument("--seed", type=int, default=1234, help="Seed for evaluation.")

    # Classifier-Free Guidance
    # The default "magic" suffixes below are Chinese quality prompts: the positive one roughly reads
    # "ultra HD, HDR video, ambient light, Dolby Atmos, stable footage, smooth motion, realistic detail,
    # professional composition, surreal, natural, vivid, ultra detailed, sharp", and the negative one
    # "dark footage, low resolution, bad hands, text, missing fingers, extra fingers, cropped, low quality,
    # grainy, signature, watermark, username, blurry".
    group.add_argument("--pos_magic", type=str,
                       default="超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。",
                       help="Positive magic prompt for sampling.")
    group.add_argument("--neg_magic", type=str,
                       default="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
                       help="Negative magic prompt for sampling.")
    group.add_argument("--cfg_scale", type=float, default=9.0, help="Classifier free guidance scale.")
    return parser


def add_parallel_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Parallel args")

    # ======================== Model loads ========================
    group.add_argument("--ulysses_degree", type=int, default=8, help="Ulysses degree.")
    group.add_argument("--ring_degree", type=int, default=1, help="Ring degree.")
    return parser
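
For reference only (not part of the diff), a minimal sketch of how this config module is typically consumed; the script name and flag values shown are hypothetical:

    # Hypothetical invocation:
    #   python run_stepvideo.py --model_dir ./ckpts --prompt "a cat surfing" \
    #       --infer_steps 50 --cfg_scale 9.0 --ulysses_degree 8
    from fastvideo.models.stepvideo.config import parse_args

    args = parse_args()  # builds the four argument groups defined above
    print(args.model_dir, args.infer_steps, args.cfg_scale)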
FastVideo-main/fastvideo/models/stepvideo/diffusion/scheduler.py (new file, 0 → 100644)

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput, logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        reverse (`bool`, defaults to `True`):
            Whether to reverse the timestep schedule.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        reverse: bool = False,
        solver: str = "euler",
        device: Union[str, torch.device] = None,
    ):
        sigmas = torch.linspace(1, 0, num_train_timesteps + 1)

        if not reverse:
            sigmas = sigmas.flip(0)

        self.sigmas = sigmas
        # the value fed to model
        self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)

        self._step_index = None
        self._begin_index = None

        self.device = device
        self.supported_solver = ["euler"]
        if solver not in self.supported_solver:
            raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}")

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self,
        num_inference_steps: int,
        time_shift: float = 13.0,
        device: Union[str, torch.device] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            n_tokens (`int`, *optional*):
                Number of tokens in the input sequence.
        """
        device = device or self.device
        self.num_inference_steps = num_inference_steps

        sigmas = torch.linspace(1, 0, num_inference_steps + 1, device=device)
        sigmas = self.sd3_time_shift(sigmas, time_shift)

        if not self.config.reverse:
            sigmas = 1 - sigmas

        self.sigmas = sigmas
        self.timesteps = sigmas[:-1]

        # Reset step index
        self._step_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        return sample

    def sd3_time_shift(self, t: torch.Tensor, time_shift: float = 13.0):
        return (time_shift * t) / (1 + (time_shift - 1) * t)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        return_dict: bool = False,
    ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            n_tokens (`int`, *optional*):
                Number of tokens in the input sequence.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """
        if (isinstance(timestep, int) or isinstance(timestep, torch.IntTensor)
                or isinstance(timestep, torch.LongTensor)):
            raise ValueError(
                ("Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                 " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                 " one of the `scheduler.timesteps` as a timestep."),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]

        if self.config.solver == "euler":
            prev_sample = sample + model_output.to(torch.float32) * dt
        else:
            raise ValueError(
                f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return prev_sample

        return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
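
Not part of the commit: a minimal sketch of how FlowMatchDiscreteScheduler is driven in a denoising loop, using a stand-in denoiser and a stand-in latent shape; the real pipeline below passes the transformer's prediction instead.

    import torch
    # from fastvideo.models.stepvideo.diffusion.scheduler import FlowMatchDiscreteScheduler

    scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=1000, solver="euler")
    scheduler.set_timesteps(num_inference_steps=50, time_shift=13.0, device="cpu")

    sample = torch.randn(1, 36, 64, 34, 62)            # stand-in latent; the real shape comes from prepare_latents()
    for t in scheduler.timesteps:
        model_output = torch.zeros_like(sample)         # stand-in for the transformer's velocity prediction
        sample = scheduler.step(model_output, t, sample)  # Euler update: x += v * (sigma[i+1] - sigma[i])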
FastVideo-main/fastvideo/models/stepvideo/diffusion/video_pipeline.py (new file, 0 → 100644)

# Copyright 2025 StepFun Inc. All Rights Reserved.
import asyncio
import pickle
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import numpy as np
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput

from fastvideo.models.stepvideo.diffusion.scheduler import FlowMatchDiscreteScheduler
from fastvideo.models.stepvideo.modules.model import StepVideoModel
from fastvideo.models.stepvideo.utils import VideoProcessor


def call_api_gen(url, api, port=8080):
    url = f"http://{url}:{port}/{api}-api"
    import aiohttp

    async def _fn(samples, *args, **kwargs):
        if api == 'vae':
            data = {
                "samples": samples,
            }
        elif api == 'caption':
            data = {
                "prompts": samples,
            }
        else:
            raise Exception(f"Not supported api: {api}...")

        async with aiohttp.ClientSession() as sess:
            data_bytes = pickle.dumps(data)
            async with sess.get(url, data=data_bytes, timeout=12000) as response:
                result = bytearray()
                while not response.content.at_eof():
                    chunk = await response.content.read(1024)
                    result += chunk
                response_data = pickle.loads(result)
        return response_data

    return _fn


@dataclass
class StepVideoPipelineOutput(BaseOutput):
    video: Union[torch.Tensor, np.ndarray]


class StepVideoPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using StepVideo.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        transformer ([`StepVideoModel`]):
            Conditional Transformer to denoise the encoded image latents.
        scheduler ([`FlowMatchDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae_url:
            remote vae server's url.
        caption_url:
            remote caption (stepllm and clip) server's url.
    """

    def __init__(
        self,
        transformer: StepVideoModel,
        scheduler: FlowMatchDiscreteScheduler,
        vae_url: str = '127.0.0.1',
        caption_url: str = '127.0.0.1',
        save_path: str = './results',
        name_suffix: str = '',
    ):
        super().__init__()

        self.register_modules(
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 8
        self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16
        self.video_processor = VideoProcessor(save_path, name_suffix)

        self.vae_url = vae_url
        self.caption_url = caption_url
        self.setup_api(self.vae_url, self.caption_url)

    def setup_api(self, vae_url, caption_url):
        self.vae_url = vae_url
        self.caption_url = caption_url
        self.caption = call_api_gen(caption_url, 'caption')
        self.vae = call_api_gen(vae_url, 'vae')
        return self

    def encode_prompt(
        self,
        prompt: str,
        neg_magic: str = '',
        pos_magic: str = '',
    ):
        device = self._execution_device
        prompts = [prompt + pos_magic]
        bs = len(prompts)
        prompts += [neg_magic] * bs

        data = asyncio.run(self.caption(prompts))
        prompt_embeds, prompt_attention_mask, clip_embedding = (
            data['y'].to(device),
            data['y_mask'].to(device),
            data['clip_embedding'].to(device),
        )

        return prompt_embeds, clip_embedding, prompt_attention_mask

    def decode_vae(self, samples):
        samples = asyncio.run(self.vae(samples.cpu()))
        return samples

    def check_inputs(self, num_frames, width, height):
        num_frames = max(num_frames // 17 * 17, 1)
        width = max(width // 16 * 16, 16)
        height = max(height // 16 * 16, 16)
        return num_frames, width, height

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: 64,
        height: int = 544,
        width: int = 992,
        num_frames: int = 204,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        num_frames, width, height = self.check_inputs(num_frames, width, height)
        shape = (
            batch_size,
            max(num_frames // 17 * 3, 1),
            num_channels_latents,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )  # b,f,c,h,w

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if generator is None:
            generator = torch.Generator(device=self._execution_device)

        latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        return latents

    @torch.inference_mode()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: int = 544,
        width: int = 992,
        num_frames: int = 204,
        num_inference_steps: int = 50,
        guidance_scale: float = 9.0,
        time_shift: float = 13.0,
        neg_magic: str = "",
        pos_magic: str = "",
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "mp4",
        output_file_name: Optional[str] = "",
        return_dict: bool = True,
        mask_strategy: Optional[Dict[str, list]] = None,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, defaults to `544`):
                The height in pixels of the generated image.
            width (`int`, defaults to `992`):
                The width in pixels of the generated image.
            num_frames (`int`, defaults to `204`):
                The number of frames in the generated video.
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, defaults to `9.0`):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            output_file_name (`str`, *optional*):
                The output mp4 file name.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`StepVideoPipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~StepVideoPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`StepVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
                where the first element is a list with the generated images and the second element is a list of `bool`s
                indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
        """
        # 1. Check inputs. Raise error if not correct
        device = self._execution_device

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, prompt_embeds_2, prompt_attention_mask = self.encode_prompt(
            prompt=prompt,
            neg_magic=neg_magic,
            pos_magic=pos_magic,
        )
        transformer_dtype = self.transformer.dtype
        prompt_embeds = prompt_embeds.to(transformer_dtype)
        prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
        prompt_embeds_2 = prompt_embeds_2.to(transformer_dtype)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, time_shift=time_shift, device=device)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.bfloat16,
            device,
            generator,
            latents,
        )

        def dict_to_3d_list(best_masks, t_max=50, l_max=48, h_max=48):
            result = [[[None for _ in range(h_max)] for _ in range(l_max)] for _ in range(t_max)]
            if best_masks is None:
                return result
            for key, value in best_masks.items():
                timestep, layer, head = map(int, key.split('_'))
                result[timestep][layer][head] = value
            return result

        mask_strategy = dict_to_3d_list(mask_strategy)
        # best_mask_selections = None

        # 7. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(self.scheduler.timesteps):
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = latent_model_input.to(transformer_dtype)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    encoder_hidden_states_2=prompt_embeds_2,
                    return_dict=False,
                    mask_strategy=mask_strategy[i],
                )

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(model_output=noise_pred, timestep=t, sample=latents)

                progress_bar.update()

        if not torch.distributed.is_initialized() or int(torch.distributed.get_rank()) == 0:
            if not output_type == "latent":
                video = self.decode_vae(latents)
                video = self.video_processor.postprocess_video(
                    video, output_file_name=output_file_name, output_type=output_type)
            else:
                video = latents

            # Offload all models
            self.maybe_free_model_hooks()

            if not return_dict:
                return (video, )

            return StepVideoPipelineOutput(video=video)
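
Not part of the commit: a rough sketch of wiring the pieces together, assuming a `StepVideoModel` checkpoint has been loaded elsewhere and that the remote vae/caption servers are reachable at the given addresses; everything outside the two classes above is illustrative.

    # from fastvideo.models.stepvideo.diffusion.scheduler import FlowMatchDiscreteScheduler
    # from fastvideo.models.stepvideo.diffusion.video_pipeline import StepVideoPipeline

    # `transformer` is assumed to be a StepVideoModel instance loaded from the checkpoint directory.
    scheduler = FlowMatchDiscreteScheduler(solver="euler")
    pipe = StepVideoPipeline(transformer, scheduler,
                             vae_url="127.0.0.1", caption_url="127.0.0.1",
                             save_path="./results")

    out = pipe(prompt="a corgi running on the beach",
               num_frames=204, height=544, width=992,
               num_inference_steps=50, guidance_scale=9.0, time_shift=13.0,
               output_type="mp4", output_file_name="corgi")
    video = out.video  # decoded frames, or raw latents when output_type == "latent"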
FastVideo-main/fastvideo/models/stepvideo/modules/attentions.py (new file, 0 → 100644)

import torch
import torch.nn as nn
from einops import rearrange
from flash_attn import flash_attn_func

try:
    from st_attn import sliding_tile_attention
except ImportError:
    print("Could not load Sliding Tile Attention.")
    sliding_tile_attention = None

from fastvideo.utils.communications import all_to_all_4D
from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info


class Attention(nn.Module):

    def __init__(self):
        super().__init__()

    def attn_processor(self, attn_type):
        if attn_type == 'torch':
            return self.torch_attn_func
        elif attn_type == 'parallel':
            return self.parallel_attn_func
        else:
            raise Exception('Not supported attention type...')

    def tile(self, x, sp_size):
        x = rearrange(x, "b (sp t h w) head d -> b (t sp h w) head d",
                      sp=sp_size, t=36 // sp_size, h=48, w=48)
        return rearrange(x, "b (n_t ts_t n_h ts_h n_w ts_w) h d -> b (n_t n_h n_w ts_t ts_h ts_w) h d",
                         n_t=6, n_h=6, n_w=6, ts_t=6, ts_h=8, ts_w=8)

    def untile(self, x, sp_size):
        x = rearrange(x, "b (n_t n_h n_w ts_t ts_h ts_w) h d -> b (n_t ts_t n_h ts_h n_w ts_w) h d",
                      n_t=6, n_h=6, n_w=6, ts_t=6, ts_h=8, ts_w=8)
        return rearrange(x, "b (t sp h w) head d -> b (sp t h w) head d",
                         sp=sp_size, t=36 // sp_size, h=48, w=48)

    def torch_attn_func(self, q, k, v, attn_mask=None, causal=False, drop_rate=0.0, **kwargs):
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        if attn_mask is not None and attn_mask.ndim == 3:  ## no head
            n_heads = q.shape[2]
            attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        q, k, v = map(lambda x: rearrange(x, 'b s h d -> b h s d'), (q, k, v))
        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
        x = rearrange(x, 'b h s d -> b s h d')
        return x

    def parallel_attn_func(self, q, k, v, causal=False, mask_strategy=None, **kwargs):
        if get_sequence_parallel_state():
            q = all_to_all_4D(q, scatter_dim=2, gather_dim=1)
            k = all_to_all_4D(k, scatter_dim=2, gather_dim=1)
            v = all_to_all_4D(v, scatter_dim=2, gather_dim=1)

        if mask_strategy[0] is not None:
            q = self.tile(q, nccl_info.sp_size).transpose(1, 2).contiguous()
            k = self.tile(k, nccl_info.sp_size).transpose(1, 2).contiguous()
            v = self.tile(v, nccl_info.sp_size).transpose(1, 2).contiguous()

            head_num = q.size(1)  # 48 // sp_size
            current_rank = nccl_info.rank_within_group
            start_head = current_rank * head_num
            windows = [mask_strategy[head_idx + start_head] for head_idx in range(head_num)]

            x = sliding_tile_attention(q, k, v, windows, 0, False).transpose(1, 2).contiguous()
            x = self.untile(x, nccl_info.sp_size)
        else:
            x = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)

        if get_sequence_parallel_state():
            x = all_to_all_4D(x, scatter_dim=1, gather_dim=2)

        x = x.to(q.dtype)
        return x
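
A side note, not in the diff: the tile/untile rearranges above assume a fixed 36 x 48 x 48 latent token grid regrouped into 6 x 8 x 8 tiles for sliding-tile attention. A small self-contained check (with sp_size = 1 and toy head/dim sizes) that untile inverts tile:

    import torch
    from einops import rearrange

    sp_size = 1                 # sequence-parallel world size assumed 1 for this check
    heads, dim = 2, 4
    seq = 36 * 48 * 48          # token layout assumed by tile(): 36 frames x 48 x 48 spatial tokens

    x = torch.randn(1, seq, heads, dim)

    # tile: interleave frames across sp ranks, then group tokens into 6x8x8 tiles
    t = rearrange(x, "b (sp t h w) head d -> b (t sp h w) head d", sp=sp_size, t=36 // sp_size, h=48, w=48)
    t = rearrange(t, "b (n_t ts_t n_h ts_h n_w ts_w) h d -> b (n_t n_h n_w ts_t ts_h ts_w) h d",
                  n_t=6, n_h=6, n_w=6, ts_t=6, ts_h=8, ts_w=8)

    # untile: the exact inverse permutation
    u = rearrange(t, "b (n_t n_h n_w ts_t ts_h ts_w) h d -> b (n_t ts_t n_h ts_h n_w ts_w) h d",
                  n_t=6, n_h=6, n_w=6, ts_t=6, ts_h=8, ts_w=8)
    u = rearrange(u, "b (t sp h w) head d -> b (sp t h w) head d", sp=sp_size, t=36 // sp_size, h=48, w=48)

    assert torch.equal(x, u)    # untile(tile(x)) recovers the original token order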
FastVideo-main/fastvideo/models/stepvideo/modules/blocks.py (new file, 0 → 100644)

# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange

from fastvideo.models.stepvideo.modules.attentions import Attention
from fastvideo.models.stepvideo.modules.normalization import RMSNorm
from fastvideo.models.stepvideo.modules.rope import RoPE3D


class SelfAttention(Attention):

    def __init__(self, hidden_dim, head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type='torch'):
        super().__init__()
        self.head_dim = head_dim
        self.n_heads = hidden_dim // head_dim

        self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=bias)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)

        self.with_rope = with_rope
        self.with_qk_norm = with_qk_norm
        if self.with_qk_norm:
            self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
            self.k_norm = RMSNorm(head_dim, elementwise_affine=True)

        if self.with_rope:
            self.rope_3d = RoPE3D(freq=1e4, F0=1.0, scaling_factor=1.0)
            self.rope_ch_split = [64, 32, 32]

        self.core_attention = self.attn_processor(attn_type=attn_type)
        self.parallel = attn_type == 'parallel'

    def apply_rope3d(self, x, fhw_positions, rope_ch_split, parallel=True):
        x = self.rope_3d(x, fhw_positions, rope_ch_split, parallel)
        return x

    def forward(self, x, cu_seqlens=None, max_seqlen=None, rope_positions=None, attn_mask=None, mask_strategy=None):
        xqkv = self.wqkv(x)
        xqkv = xqkv.view(*x.shape[:-1], self.n_heads, 3 * self.head_dim)
        xq, xk, xv = torch.split(xqkv, [self.head_dim] * 3, dim=-1)  ## seq_len, n, dim

        if self.with_qk_norm:
            xq = self.q_norm(xq)
            xk = self.k_norm(xk)

        if self.with_rope:
            xq = self.apply_rope3d(xq, rope_positions, self.rope_ch_split, parallel=self.parallel)
            xk = self.apply_rope3d(xk, rope_positions, self.rope_ch_split, parallel=self.parallel)

        output = self.core_attention(xq, xk, xv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
                                     attn_mask=attn_mask, mask_strategy=mask_strategy)
        output = rearrange(output, 'b s h d -> b s (h d)')
        output = self.wo(output)

        return output


class CrossAttention(Attention):

    def __init__(self, hidden_dim, head_dim, bias=False, with_qk_norm=True, attn_type='torch'):
        super().__init__()
        self.head_dim = head_dim
        self.n_heads = hidden_dim // head_dim

        self.wq = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.wkv = nn.Linear(hidden_dim, hidden_dim * 2, bias=bias)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)

        self.with_qk_norm = with_qk_norm
        if self.with_qk_norm:
            self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
            self.k_norm = RMSNorm(head_dim, elementwise_affine=True)

        self.core_attention = self.attn_processor(attn_type=attn_type)

    def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, attn_mask=None):
        xq = self.wq(x)
        xq = xq.view(*xq.shape[:-1], self.n_heads, self.head_dim)

        xkv = self.wkv(encoder_hidden_states)
        xkv = xkv.view(*xkv.shape[:-1], self.n_heads, 2 * self.head_dim)
        xk, xv = torch.split(xkv, [self.head_dim] * 2, dim=-1)  ## seq_len, n, dim

        if self.with_qk_norm:
            xq = self.q_norm(xq)
            xk = self.k_norm(xk)

        output = self.core_attention(xq, xk, xv, attn_mask=attn_mask)
        output = rearrange(output, 'b s h d -> b s (h d)')
        output = self.wo(output)

        return output


class GELU(nn.Module):
    r"""
    GELU activation function with tanh approximation support with `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class FeedForward(nn.Module):

    def __init__(
        self,
        dim: int,
        inner_dim: Optional[int] = None,
        dim_out: Optional[int] = None,
        mult: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        inner_dim = dim * mult if inner_dim is None else inner_dim
        dim_out = dim if dim_out is None else dim_out

        self.net = nn.ModuleList([
            GELU(dim, inner_dim, approximate="tanh", bias=bias),
            nn.Identity(),
            nn.Linear(inner_dim, dim_out, bias=bias)
        ])

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


def modulate(x, scale, shift):
    x = x * (1 + scale) + shift
    return x


def gate(x, gate):
    x = gate * x
    return x


class StepVideoTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:obj: `int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:obj: `bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
    """

    def __init__(
        self,
        dim: int,
        attention_head_dim: int,
        norm_eps: float = 1e-5,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = False,
        attention_type: str = 'parallel'
    ):
        super().__init__()
        self.dim = dim
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
        self.attn1 = SelfAttention(dim, attention_head_dim, bias=False, with_rope=True, with_qk_norm=True,
                                   attn_type=attention_type)

        self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
        self.attn2 = CrossAttention(dim, attention_head_dim, bias=False, with_qk_norm=True, attn_type='torch')

        self.ff = FeedForward(dim=dim, inner_dim=ff_inner_dim, dim_out=dim, bias=ff_bias)

        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    @torch.no_grad()
    def forward(
        self,
        q: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        attn_mask=None,
        rope_positions: list = None,
        mask_strategy=None
    ) -> torch.Tensor:
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            torch.clone(chunk)
            for chunk in (self.scale_shift_table[None] + timestep.reshape(-1, 6, self.dim)).chunk(6, dim=1))

        scale_shift_q = modulate(self.norm1(q), scale_msa, shift_msa)
        attn_q = self.attn1(scale_shift_q, rope_positions=rope_positions, mask_strategy=mask_strategy)
        q = gate(attn_q, gate_msa) + q

        attn_q = self.attn2(q, kv, attn_mask)
        q = attn_q + q

        scale_shift_q = modulate(self.norm2(q), scale_mlp, shift_mlp)
        ff_output = self.ff(scale_shift_q)
        q = gate(ff_output, gate_mlp) + q

        return q


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        patch_size=64,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
    ):
        super().__init__()
        self.flatten = flatten
        self.layer_norm = layer_norm
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size,
                              bias=bias)

    def forward(self, latent):
        latent = self.proj(latent).to(latent.dtype)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)
        return latent
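
Not part of the diff: a small sketch, with toy sizes, of the adaLN-single modulation pattern used by StepVideoTransformerBlock, showing how the learned scale_shift_table plus a per-timestep embedding is split into shift/scale/gate triples for the attention and MLP branches (the timestep embedding here is just random stand-in data).

    import torch

    dim, batch = 8, 2
    scale_shift_table = torch.randn(6, dim) / dim**0.5
    timestep_emb = torch.randn(batch, 6 * dim)        # stand-in for the timestep embedder output

    chunks = (scale_shift_table[None] + timestep_emb.reshape(-1, 6, dim)).chunk(6, dim=1)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks   # each (batch, 1, dim)

    x = torch.randn(batch, 16, dim)                   # (batch, tokens, dim)
    x_mod = x * (1 + scale_msa) + shift_msa           # modulate(): broadcast over the token axis
    x_out = gate_msa * x_mod + x                      # gate() plus residual connection
    print(x_out.shape)                                # torch.Size([2, 16, 8])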
FastVideo-main/fastvideo/models/stepvideo/modules/model.py (new file, 0 → 100644)

from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config

from .activation_layers import get_activation_layer
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from .attenion import attention, parallel_attention, get_cu_seqlens
from .posemb_layers import apply_rotary_emb
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, apply_gate
from .token_refiner import SingleTokenRefiner


class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal dit block with seperate modulation for
    text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
                                         (Flux.1): https://github.com/black-forest-labs/flux
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.img_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.img_attn_q_norm = (
            qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.img_attn_k_norm = (
            qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        self.txt_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        self.txt_attn_q_norm = (
            qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.txt_attn_k_norm = (
            qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )
        self.hybrid_seq_parallel_attn = None

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        freqs_cis: tuple = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = self.img_mod(vec).chunk(6, dim=-1)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = self.txt_mod(vec).chunk(6, dim=-1)

        # Prepare image for attention.
        img_modulated = self.img_norm1(img)
        img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
        img_qkv = self.img_attn_qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        # Apply QK-Norm if needed
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        # Apply RoPE if needed.
        if freqs_cis is not None:
            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            assert (
                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
            ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
            img_q, img_k = img_qq, img_kk

        # Prepare txt for attention.
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
        txt_qkv = self.txt_attn_qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        # Apply QK-Norm if needed.
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        # Run actual attention.
        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)
        assert (
            cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
        ), f"cu_seqlens_q.shape: {cu_seqlens_q.shape}, img.shape[0]: {img.shape[0]}"

        # attention computation start
        if not self.hybrid_seq_parallel_attn:
            attn = attention(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                batch_size=img_k.shape[0],
            )
        else:
            attn = parallel_attention(
                self.hybrid_seq_parallel_attn,
                q,
                k,
                v,
                img_q_len=img_q.shape[1],
                img_kv_len=img_k.shape[1],
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv
            )
        # attention computation end

        img_attn, txt_attn = attn[:, :img.shape[1]], attn[:, img.shape[1]:]

        # Calculate the img bloks.
        img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
        img = img + apply_gate(
            self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
            gate=img_mod2_gate,
        )

        # Calculate the txt bloks.
        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
        txt = txt + apply_gate(
            self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
            gate=txt_mod2_gate,
        )

        return img, txt


class MMSingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    Also refer to (SD3): https://arxiv.org/abs/2403.03206
                  (Flux.1): https://github.com/black-forest-labs/flux
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim
        self.scale = qk_scale or head_dim ** -0.5

        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)

        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )

        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.mlp_act = get_activation_layer(mlp_act_type)()
        self.modulation = ModulateDiT(
            hidden_size,
            factor=3,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.hybrid_seq_parallel_attn = None

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        txt_len: int,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
    ) -> torch.Tensor:
        mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
        x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)

        # Apply QK-Norm if needed.
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Apply RoPE if needed.
        if freqs_cis is not None:
            img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
            img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            assert (
                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
            ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
            img_q, img_k = img_qq, img_kk
            q = torch.cat((img_q, txt_q), dim=1)
            k = torch.cat((img_k, txt_k), dim=1)

        # Compute attention.
        assert (
            cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
        ), f"cu_seqlens_q.shape: {cu_seqlens_q.shape}, x.shape[0]: {x.shape[0]}"

        # attention computation start
        if not self.hybrid_seq_parallel_attn:
            attn = attention(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                batch_size=x.shape[0],
            )
        else:
            attn = parallel_attention(
                self.hybrid_seq_parallel_attn,
                q,
                k,
                v,
                img_q_len=img_q.shape[1],
                img_kv_len=img_k.shape[1],
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv
            )
        # attention computation end

        # Compute activation in mlp stream, cat again and run second linear layer.
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + apply_gate(output, gate=mod_gate)


class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
    """
    HunyuanVideo Transformer backbone

    Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.

    Reference:
    [1] Flux.1: https://github.com/black-forest-labs/flux
    [2] MMDiT: http://arxiv.org/abs/2403.03206

    Parameters
    ----------
    args: argparse.Namespace
        The arguments parsed by argparse.
    patch_size: list
        The size of the patch.
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    heads_num: int
        The number of attention heads.
    mlp_width_ratio: float
        The ratio of the hidden size of the MLP in the transformer block.
    mlp_act_type: str
        The activation function of the MLP in the transformer block.
    depth_double_blocks: int
        The number of transformer blocks in the double blocks.
    depth_single_blocks: int
        The number of transformer blocks in the single blocks.
    rope_dim_list: list
        The dimension of the rotary embedding for t, h, w.
    qkv_bias: bool
        Whether to use bias in the qkv linear layer.
    qk_norm: bool
        Whether to use qk norm.
    qk_norm_type: str
        The type of qk norm.
    guidance_embed: bool
        Whether to use guidance embedding for distillation.
    text_projection: str
        The type of the text projection, default is single_refiner.
    use_attention_mask: bool
        Whether to use attention mask for text encoder.
    dtype: torch.dtype
        The dtype of the model.
    device: torch.device
        The device of the model.
    """

    @register_to_config
    def __init__(
        self,
        args: Any,
        patch_size: list = [1, 2, 2],
        in_channels: int = 4,  # Should be VAE.config.latent_channels.
        out_channels: int = None,
        hidden_size: int = 3072,
        heads_num: int = 24,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        mm_double_blocks_depth: int = 20,
        mm_single_blocks_depth: int = 40,
        rope_dim_list: List[int] = [16, 56, 56],
        qkv_bias: bool = True,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        guidance_embed: bool = False,  # For modulation.
        text_projection: str = "single_refiner",
        use_attention_mask: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.unpatchify_channels = self.out_channels
        self.guidance_embed = guidance_embed
        self.rope_dim_list = rope_dim_list

        # Text projection. Default to linear projection.
        # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
        self.use_attention_mask = use_attention_mask
        self.text_projection = text_projection

        self.text_states_dim = args.text_states_dim
        self.text_states_dim_2 = args.text_states_dim_2

        if hidden_size % heads_num != 0:
            raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
        pe_dim = hidden_size // heads_num
        if sum(rope_dim_list) != pe_dim:
            raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
        self.hidden_size = hidden_size
        self.heads_num = heads_num

        # image projection
        self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)

        # text projection
        if self.text_projection == "linear":
            self.txt_in = TextProjection(
                self.text_states_dim,
                self.hidden_size,
                get_activation_layer("silu"),
                **factory_kwargs,
            )
        elif self.text_projection == "single_refiner":
            self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs)
        else:
            raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")

        # time modulation
        self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)

        # text modulation
        self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs)

        # guidance modulation
        self.guidance_in = (
            TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
            if guidance_embed else None
        )

        # double blocks
        self.double_blocks = nn.ModuleList(
            [
                MMDoubleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    **factory_kwargs,
                )
                for _ in range(mm_double_blocks_depth)
            ]
        )

        # single blocks
        self.single_blocks = nn.ModuleList(
            [
                MMSingleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    **factory_kwargs,
                )
                for _ in range(mm_single_blocks_depth)
            ]
        )

        self.final_layer = FinalLayer(
            self.hidden_size,
            self.patch_size,
            self.out_channels,
            get_activation_layer("silu"),
            **factory_kwargs,
        )

    def enable_deterministic(self):
        for block in self.double_blocks:
            block.enable_deterministic()
        for block in self.single_blocks:
            block.enable_deterministic()

    def disable_deterministic(self):
        for block in self.double_blocks:
            block.disable_deterministic()
        for block in self.single_blocks:
            block.disable_deterministic()

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,  # Now we don't use it.
        text_states_2: Optional[torch.Tensor] = None,  # Text embedding for modulation.
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        out = {}
        img = x
        txt = text_states
        _, _, ot, oh, ow = x.shape
        tt, th, tw = (
            ot // self.patch_size[0],
            oh // self.patch_size[1],
            ow // self.patch_size[2],
        )

        # Prepare modulation vectors.
        vec = self.time_in(t)

        # text modulation
        vec = vec + self.vector_in(text_states_2)

        # guidance modulation
        if self.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")

            # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
            vec = vec + self.guidance_in(guidance)

        # Embed image and text.
        img = self.img_in(img)
        if self.text_projection == "linear":
            txt = self.txt_in(txt)
        elif self.text_projection == "single_refiner":
            txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
        else:
            raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")

        txt_seq_len = txt.shape[1]
        img_seq_len = img.shape[1]

        # Compute cu_squlens and max_seqlen for flash attention
        cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
        cu_seqlens_kv = cu_seqlens_q
        max_seqlen_q = img_seq_len + txt_seq_len
        max_seqlen_kv = max_seqlen_q

        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
        # --------------------- Pass through DiT blocks ------------------------
        for _, block in enumerate(self.double_blocks):
            double_block_args = [
                img,
                txt,
                vec,
                cu_seqlens_q,
                cu_seqlens_kv,
                max_seqlen_q,
                max_seqlen_kv,
                freqs_cis,
            ]

            img, txt = block(*double_block_args)

        # Merge txt and img to pass through single stream blocks.
        x = torch.cat((img, txt), 1)
        if len(self.single_blocks) > 0:
            for _, block in enumerate(self.single_blocks):
                single_block_args = [
                    x,
                    vec,
                    txt_seq_len,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    (freqs_cos, freqs_sin),
                ]

                x = block(*single_block_args)

        img = x[:, :img_seq_len, ...]

        # ---------------------------- Final layer ------------------------------
        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)

        img = self.unpatchify(img, tt, th, tw)
        if return_dict:
            out["x"] = img
            return out
        return img

    def unpatchify(self, x, t, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.unpatchify_channels
        pt, ph, pw = self.patch_size
        assert t * h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
        x = torch.einsum("nthwcopq->nctohpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

        return imgs

    def params_count(self):
        counts = {
            "double": sum(
                [
                    sum(p.numel() for p in block.img_attn_qkv.parameters())
                    + sum(p.numel() for p in block.img_attn_proj.parameters())
                    + sum(p.numel() for p in block.img_mlp.parameters())
                    + sum(p.numel() for p in block.txt_attn_qkv.parameters())
                    + sum(p.numel() for p in block.txt_attn_proj.parameters())
                    + sum(p.numel() for p in block.txt_mlp.parameters())
                    for block in self.double_blocks
                ]
            ),
            "single": sum(
                [
                    sum(p.numel() for p in block.linear1.parameters())
                    + sum(p.numel() for p in block.linear2.parameters())
                    for block in self.single_blocks
                ]
            ),
            "total": sum(p.numel() for
p
in
self
.
parameters
()),
}
counts
[
"attn+mlp"
]
=
counts
[
"double"
]
+
counts
[
"single"
]
return
counts
#################################################################################
# HunyuanVideo Configs #
#################################################################################
HUNYUAN_VIDEO_CONFIG
=
{
"HYVideo-T/2"
:
{
"mm_double_blocks_depth"
:
20
,
"mm_single_blocks_depth"
:
40
,
"rope_dim_list"
:
[
16
,
56
,
56
],
"hidden_size"
:
3072
,
"heads_num"
:
24
,
"mlp_width_ratio"
:
4
,
},
"HYVideo-T/2-cfgdistill"
:
{
"mm_double_blocks_depth"
:
20
,
"mm_single_blocks_depth"
:
40
,
"rope_dim_list"
:
[
16
,
56
,
56
],
"hidden_size"
:
3072
,
"heads_num"
:
24
,
"mlp_width_ratio"
:
4
,
"guidance_embed"
:
True
,
},
}
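A hedged sketch of how one of these config entries could be expanded into constructor kwargs, assuming the surrounding fastvideo modules are importable; the `args` namespace, the dimension values, and the shrunken block depths are illustrative assumptions, not values taken from this commit.

# Illustrative only: assumed text-encoder dims and latent channels, tiny depths for speed.
from types import SimpleNamespace

args = SimpleNamespace(text_states_dim=4096, text_states_dim_2=768)      # assumed dims
cfg = dict(HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"])
cfg.update(mm_double_blocks_depth=1, mm_single_blocks_depth=1)           # shrink for the sketch
model = HYVideoDiffusionTransformer(args, in_channels=16, **cfg)         # 16 = assumed VAE latent channels
print(model.params_count())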
FastVideo-main/fastvideo/models/stepvideo/modules/model.py-bak
0 → 100644
View file @
c07946d8
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
from typing import Dict, Optional
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange, repeat
from torch import nn
from fastvideo.models.stepvideo.modules.blocks import PatchEmbed, StepVideoTransformerBlock
from fastvideo.models.stepvideo.modules.normalization import AdaLayerNormSingle, PixArtAlphaTextProjection
from fastvideo.models.stepvideo.parallel import parallel_forward
from fastvideo.models.stepvideo.utils import with_empty_init
class StepVideoModel(ModelMixin, ConfigMixin):
_no_split_modules = ["StepVideoTransformerBlock", "PatchEmbed"]
@with_empty_init
@register_to_config
def __init__(
self,
num_attention_heads: int = 48,
attention_head_dim: int = 128,
in_channels: int = 64,
out_channels: Optional[int] = 64,
num_layers: int = 48,
dropout: float = 0.0,
patch_size: int = 1,
norm_type: str = "ada_norm_single",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
use_additional_conditions: Optional[bool] = False,
caption_channels: Optional[int] | list | tuple = [6144, 1024],
attention_type: Optional[str] = "parallel",
):
super().__init__()
# Set some common variables used across the board.
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.out_channels = in_channels if out_channels is None else out_channels
self.use_additional_conditions = use_additional_conditions
self.pos_embed = PatchEmbed(
patch_size=patch_size,
in_channels=self.config.in_channels,
embed_dim=self.inner_dim,
)
self.transformer_blocks = nn.ModuleList([
StepVideoTransformerBlock(dim=self.inner_dim,
attention_head_dim=self.config.attention_head_dim,
attention_type=attention_type) for _ in range(self.config.num_layers)
])
# 3. Output blocks.
self.norm_out = nn.LayerNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)
self.patch_size = patch_size
self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=self.use_additional_conditions)
if isinstance(self.config.caption_channels, int):
caption_channel = self.config.caption_channels
else:
caption_channel, clip_channel = self.config.caption_channels
self.clip_projection = nn.Linear(clip_channel, self.inner_dim)
self.caption_norm = nn.LayerNorm(caption_channel, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channel, hidden_size=self.inner_dim)
self.parallel = attention_type == 'parallel'
def patchfy(self, hidden_states):
hidden_states = rearrange(hidden_states, 'b f c h w -> (b f) c h w')
hidden_states = self.pos_embed(hidden_states)
return hidden_states
def prepare_attn_mask(self, encoder_attention_mask, encoder_hidden_states, q_seqlen):
kv_seqlens = encoder_attention_mask.sum(dim=1).int()
mask = torch.zeros([len(kv_seqlens), q_seqlen, max(kv_seqlens)],
dtype=torch.bool,
device=encoder_attention_mask.device)
encoder_hidden_states = encoder_hidden_states[:, :max(kv_seqlens)]
for i, kv_len in enumerate(kv_seqlens):
mask[i, :, :kv_len] = 1
return encoder_hidden_states, mask
@parallel_forward
def block_forward(self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
rope_positions=None,
attn_mask=None,
parallel=True,
mask_strategy=None):
for i, block in enumerate(self.transformer_blocks):
hidden_states = block(hidden_states,
encoder_hidden_states,
timestep=timestep,
attn_mask=attn_mask,
rope_positions=rope_positions,
mask_strategy=mask_strategy[i])
return hidden_states
@torch.inference_mode()
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_hidden_states_2: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
fps: torch.Tensor = None,
return_dict: bool = True,
mask_strategy=None,
):
        assert hidden_states.ndim == 5, "hidden_states's shape should be (bsz, f, ch, h, w)"
bsz, frame, _, height, width = hidden_states.shape
height, width = height // self.patch_size, width // self.patch_size
hidden_states = self.patchfy(hidden_states)
len_frame = hidden_states.shape[1]
if self.use_additional_conditions:
added_cond_kwargs = {
"resolution": torch.tensor([(height, width)] * bsz,
device=hidden_states.device,
dtype=hidden_states.dtype),
"nframe": torch.tensor([frame] * bsz, device=hidden_states.device, dtype=hidden_states.dtype),
"fps": fps
}
else:
added_cond_kwargs = {}
timestep, embedded_timestep = self.adaln_single(timestep, added_cond_kwargs=added_cond_kwargs)
encoder_hidden_states = self.caption_projection(self.caption_norm(encoder_hidden_states))
if encoder_hidden_states_2 is not None and hasattr(self, 'clip_projection'):
clip_embedding = self.clip_projection(encoder_hidden_states_2)
encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=1)
hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous()
encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask,
encoder_hidden_states,
q_seqlen=frame * len_frame)
hidden_states = self.block_forward(hidden_states,
encoder_hidden_states,
timestep=timestep,
rope_positions=[frame, height, width],
attn_mask=attn_mask,
parallel=self.parallel,
mask_strategy=mask_strategy)
hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame).contiguous()
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
# unpatchify
hidden_states = hidden_states.reshape(shape=(-1, height, width, self.patch_size, self.patch_size,
self.out_channels))
hidden_states = rearrange(hidden_states, 'n h w p q c -> n c h p w q')
output = hidden_states.reshape(shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size))
output = rearrange(output, '(b f) c h w -> b f c h w', f=frame)
if return_dict:
return {'x': output}
return output
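For readers tracing shapes through StepVideoModel.forward, here is a minimal sketch of the patchify/unpatchify index bookkeeping using plain einops reshapes and assumed dummy sizes; the real model uses a convolutional PatchEmbed and a learned proj_out rather than this identity round trip.

import torch
from einops import rearrange

bsz, frame, ch, h, w, p = 1, 4, 64, 32, 32, 1       # assumed sizes; p is the patch size
x = torch.randn(bsz, frame, ch, h, w)
tokens = rearrange(x, 'b f c (h p) (w q) -> b (f h w) (p q c)', p=p, q=p)   # patchify to tokens
video = rearrange(tokens, 'b (f h w) (p q c) -> b f c (h p) (w q)',
                  f=frame, h=h // p, w=w // p, p=p, q=p)                     # unpatchify back
assert torch.equal(x, video)                          # pure reshape round trip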
FastVideo-main/fastvideo/models/stepvideo/modules/model.py-new
0 → 100644
View file @
c07946d8
from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from .activation_layers import get_activation_layer
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from .attenion import attention, parallel_attention, get_cu_seqlens
from .posemb_layers import apply_rotary_emb
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, apply_gate
from .token_refiner import SingleTokenRefiner
class MMDoubleStreamBlock(nn.Module):
"""
    A multimodal DiT block with separate modulation for
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qkv_bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.img_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.img_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.img_attn_q_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_k_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.img_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.txt_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
self.txt_attn_q_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_k_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.txt_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
freqs_cis: tuple = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = self.img_mod(vec).chunk(6, dim=-1)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.txt_mod(vec).chunk(6, dim=-1)
# Prepare image for attention.
img_modulated = self.img_norm1(img)
img_modulated = modulate(
img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
)
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(
img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
# Apply QK-Norm if needed
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
# Apply RoPE if needed.
if freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
# Prepare txt for attention.
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(
txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(
txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
# Apply QK-Norm if needed.
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
# Run actual attention.
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
assert (
cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
# attention computation start
if not self.hybrid_seq_parallel_attn:
attn = attention(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
batch_size=img_k.shape[0],
)
else:
attn = parallel_attention(
self.hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv
)
# attention computation end
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
        # Calculate the img blocks.
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
img = img + apply_gate(
self.img_mlp(
modulate(
self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
)
),
gate=img_mod2_gate,
)
        # Calculate the txt blocks.
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
txt = txt + apply_gate(
self.txt_mlp(
modulate(
self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
)
),
gate=txt_mod2_gate,
)
return img, txt
class MMSingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
Also refer to (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
self.scale = qk_scale or head_dim ** -0.5
# qkv and mlp_in
self.linear1 = nn.Linear(
hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
)
# proj and mlp_out
self.linear2 = nn.Linear(
hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.pre_norm = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.mlp_act = get_activation_layer(mlp_act_type)()
self.modulation = ModulateDiT(
hidden_size,
factor=3,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
qkv, mlp = torch.split(
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
# Apply RoPE if needed.
if freqs_cis is not None:
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
# Compute attention.
assert (
cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
# attention computation start
if not self.hybrid_seq_parallel_attn:
attn = attention(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
batch_size=x.shape[0],
)
else:
attn = parallel_attention(
self.hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv
)
# attention computation end
# Compute activation in mlp stream, cat again and run second linear layer.
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + apply_gate(output, gate=mod_gate)
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
"""
HunyuanVideo Transformer backbone
Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
Reference:
[1] Flux.1: https://github.com/black-forest-labs/flux
[2] MMDiT: http://arxiv.org/abs/2403.03206
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
patch_size: list
The size of the patch.
in_channels: int
The number of input channels.
out_channels: int
The number of output channels.
hidden_size: int
The hidden size of the transformer backbone.
heads_num: int
The number of attention heads.
mlp_width_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
mlp_act_type: str
The activation function of the MLP in the transformer block.
    mm_double_blocks_depth: int
        The number of transformer blocks in the double blocks.
    mm_single_blocks_depth: int
        The number of transformer blocks in the single blocks.
rope_dim_list: list
The dimension of the rotary embedding for t, h, w.
qkv_bias: bool
Whether to use bias in the qkv linear layer.
qk_norm: bool
Whether to use qk norm.
qk_norm_type: str
The type of qk norm.
guidance_embed: bool
Whether to use guidance embedding for distillation.
text_projection: str
The type of the text projection, default is single_refiner.
use_attention_mask: bool
Whether to use attention mask for text encoder.
dtype: torch.dtype
The dtype of the model.
device: torch.device
The device of the model.
"""
@register_to_config
def __init__(
self,
args: Any,
patch_size: list = [1, 2, 2],
in_channels: int = 4, # Should be VAE.config.latent_channels.
out_channels: int = None,
hidden_size: int = 3072,
heads_num: int = 24,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
mm_double_blocks_depth: int = 20,
mm_single_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
qkv_bias: bool = True,
qk_norm: bool = True,
qk_norm_type: str = "rms",
guidance_embed: bool = False, # For modulation.
text_projection: str = "single_refiner",
use_attention_mask: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.unpatchify_channels = self.out_channels
self.guidance_embed = guidance_embed
self.rope_dim_list = rope_dim_list
# Text projection. Default to linear projection.
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
self.use_attention_mask = use_attention_mask
self.text_projection = text_projection
self.text_states_dim = args.text_states_dim
self.text_states_dim_2 = args.text_states_dim_2
if hidden_size % heads_num != 0:
raise ValueError(
f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
)
pe_dim = hidden_size // heads_num
if sum(rope_dim_list) != pe_dim:
raise ValueError(
f"Got {rope_dim_list} but expected positional dim {pe_dim}"
)
self.hidden_size = hidden_size
self.heads_num = heads_num
# image projection
self.img_in = PatchEmbed(
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
)
# text projection
if self.text_projection == "linear":
self.txt_in = TextProjection(
self.text_states_dim,
self.hidden_size,
get_activation_layer("silu"),
**factory_kwargs,
)
elif self.text_projection == "single_refiner":
self.txt_in = SingleTokenRefiner(
self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
# time modulation
self.time_in = TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
# text modulation
self.vector_in = MLPEmbedder(
self.text_states_dim_2, self.hidden_size, **factory_kwargs
)
# guidance modulation
self.guidance_in = (
TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
if guidance_embed
else None
)
# double blocks
self.double_blocks = nn.ModuleList(
[
MMDoubleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,
)
for _ in range(mm_double_blocks_depth)
]
)
# single blocks
self.single_blocks = nn.ModuleList(
[
MMSingleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
)
for _ in range(mm_single_blocks_depth)
]
)
self.final_layer = FinalLayer(
self.hidden_size,
self.patch_size,
self.out_channels,
get_activation_layer("silu"),
**factory_kwargs,
)
def enable_deterministic(self):
for block in self.double_blocks:
block.enable_deterministic()
for block in self.single_blocks:
block.enable_deterministic()
def disable_deterministic(self):
for block in self.double_blocks:
block.disable_deterministic()
for block in self.single_blocks:
block.disable_deterministic()
def forward(
self,
x: torch.Tensor,
t: torch.Tensor, # Should be in range(0, 1000).
text_states: torch.Tensor = None,
text_mask: torch.Tensor = None, # Now we don't use it.
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
freqs_cos: Optional[torch.Tensor] = None,
freqs_sin: Optional[torch.Tensor] = None,
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
out = {}
img = x
txt = text_states
_, _, ot, oh, ow = x.shape
tt, th, tw = (
ot // self.patch_size[0],
oh // self.patch_size[1],
ow // self.patch_size[2],
)
# Prepare modulation vectors.
vec = self.time_in(t)
# text modulation
vec = vec + self.vector_in(text_states_2)
# guidance modulation
if self.guidance_embed:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
# our timestep_embedding is merged into guidance_in(TimestepEmbedder)
vec = vec + self.guidance_in(guidance)
# Embed image and text.
img = self.img_in(img)
if self.text_projection == "linear":
txt = self.txt_in(txt)
elif self.text_projection == "single_refiner":
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
# Compute cu_squlens and max_seqlen for flash attention
cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
cu_seqlens_kv = cu_seqlens_q
max_seqlen_q = img_seq_len + txt_seq_len
max_seqlen_kv = max_seqlen_q
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
# --------------------- Pass through DiT blocks ------------------------
for _, block in enumerate(self.double_blocks):
double_block_args = [
img,
txt,
vec,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
freqs_cis,
]
img, txt = block(*double_block_args)
# Merge txt and img to pass through single stream blocks.
x = torch.cat((img, txt), 1)
if len(self.single_blocks) > 0:
for _, block in enumerate(self.single_blocks):
single_block_args = [
x,
vec,
txt_seq_len,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
(freqs_cos, freqs_sin),
]
x = block(*single_block_args)
img = x[:, :img_seq_len, ...]
# ---------------------------- Final layer ------------------------------
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
img = self.unpatchify(img, tt, th, tw)
if return_dict:
out["x"] = img
return out
return img
def unpatchify(self, x, t, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.unpatchify_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def params_count(self):
counts = {
"double": sum(
[
sum(p.numel() for p in block.img_attn_qkv.parameters())
+ sum(p.numel() for p in block.img_attn_proj.parameters())
+ sum(p.numel() for p in block.img_mlp.parameters())
+ sum(p.numel() for p in block.txt_attn_qkv.parameters())
+ sum(p.numel() for p in block.txt_attn_proj.parameters())
+ sum(p.numel() for p in block.txt_mlp.parameters())
for block in self.double_blocks
]
),
"single": sum(
[
sum(p.numel() for p in block.linear1.parameters())
+ sum(p.numel() for p in block.linear2.parameters())
for block in self.single_blocks
]
),
"total": sum(p.numel() for p in self.parameters()),
}
counts["attn+mlp"] = counts["double"] + counts["single"]
return counts
#################################################################################
# HunyuanVideo Configs #
#################################################################################
HUNYUAN_VIDEO_CONFIG = {
"HYVideo-T/2": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
},
"HYVideo-T/2-cfgdistill": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
"guidance_embed": True,
},
}
FastVideo-main/fastvideo/models/stepvideo/modules/normalization.py
0 → 100644
View file @
c07946d8
import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn


class RMSNorm(nn.Module):

    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output


ACTIVATION_FUNCTIONS = {
    "swish": nn.SiLU(),
    "silu": nn.SiLU(),
    "mish": nn.Mish(),
    "gelu": nn.GELU(),
    "relu": nn.ReLU(),
}


def get_activation(act_fn: str) -> nn.Module:
    """Helper function to get activation function from string.

    Args:
        act_fn (str): Name of activation function.

    Returns:
        nn.Module: Activation function.
    """
    act_fn = act_fn.lower()
    if act_fn in ACTIVATION_FUNCTIONS:
        return ACTIVATION_FUNCTIONS[act_fn]
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class Timesteps(nn.Module):

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


class TimestepEmbedding(nn.Module):

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()
        linear_cls = nn.Linear

        self.linear_1 = linear_cls(in_channels, time_embed_dim, bias=sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = linear_cls(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, bias=sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if self.use_additional_conditions:
            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.nframe_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
            self.fps_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep, resolution=None, nframe=None, fps=None):
        hidden_dtype = next(self.timestep_embedder.parameters()).dtype
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        if self.use_additional_conditions:
            batch_size = timestep.shape[0]
            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
            nframe_emb = self.additional_condition_proj(nframe.flatten()).to(hidden_dtype)
            nframe_emb = self.nframe_embedder(nframe_emb).reshape(batch_size, -1)
            conditioning = timesteps_emb + resolution_emb + nframe_emb

            if fps is not None:
                fps_emb = self.additional_condition_proj(fps.flatten()).to(hidden_dtype)
                fps_emb = self.fps_embedder(fps_emb).reshape(batch_size, -1)
                conditioning = conditioning + fps_emb
        else:
            conditioning = timesteps_emb

        return conditioning


class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, time_step_rescale=1000):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 2, use_additional_conditions=use_additional_conditions)

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

        self.time_step_rescale = time_step_rescale  ## timestep usually in [0, 1], we rescale it to [0,1000] for stability

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        embedded_timestep = self.emb(timestep * self.time_step_rescale, **added_cond_kwargs)

        out = self.linear(self.silu(embedded_timestep))

        return out, embedded_timestep


class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size):
        super().__init__()
        self.linear_1 = nn.Linear(in_features, hidden_size, bias=True)
        self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = nn.Linear(hidden_size, hidden_size, bias=True)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
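A minimal usage sketch of the sinusoidal timestep embedding defined above, with assumed dummy timesteps; 256 channels mirrors what the PixArt-style embedders in this file use.

import torch

t = torch.tensor([0.0, 250.0, 999.0])                                     # assumed timesteps
proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
emb = proj(t)                                                              # -> (3, 256), cos/sin halves
print(emb.shape)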
FastVideo-main/fastvideo/models/stepvideo/modules/rope.py
0 → 100644
View file @
c07946d8
import torch

from fastvideo.utils.parallel_states import nccl_info


class RoPE1D:

    def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
        self.base = freq
        self.F0 = F0
        self.scaling_factor = scaling_factor
        self.cache = {}

    def get_cos_sin(self, D, seq_len, device, dtype):
        if (D, seq_len, device, dtype) not in self.cache:
            inv_freq = 1.0 / (self.base**(torch.arange(0, D, 2).float().to(device) / D))
            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
            freqs = torch.cat((freqs, freqs), dim=-1)
            cos = freqs.cos()  # (Seq, Dim)
            sin = freqs.sin()
            self.cache[D, seq_len, device, dtype] = (cos, sin)
        return self.cache[D, seq_len, device, dtype]

    @staticmethod
    def rotate_half(x):
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope1d(self, tokens, pos1d, cos, sin):
        assert pos1d.ndim == 2
        cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :]
        sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :]
        return (tokens * cos) + (self.rotate_half(tokens) * sin)

    def __call__(self, tokens, positions):
        """
        input:
            * tokens: batch_size x ntokens x nheads x dim
            * positions: batch_size x ntokens (t position of each token)
        output:
            * tokens after applying RoPE2D (batch_size x ntokens x nheads x dim)
        """
        D = tokens.size(3)
        assert positions.ndim == 2  # Batch, Seq
        cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
        tokens = self.apply_rope1d(tokens, positions, cos, sin)
        return tokens


class RoPE3D(RoPE1D):

    def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
        super(RoPE3D, self).__init__(freq, F0, scaling_factor)
        self.position_cache = {}

    def get_mesh_3d(self, rope_positions, bsz):
        f, h, w = rope_positions
        if f"{f}-{h}-{w}" not in self.position_cache:
            x = torch.arange(f, device='cpu')
            y = torch.arange(h, device='cpu')
            z = torch.arange(w, device='cpu')
            self.position_cache[f"{f}-{h}-{w}"] = torch.cartesian_prod(x, y, z).view(1, f * h * w,
                                                                                     3).expand(bsz, -1, 3)
        return self.position_cache[f"{f}-{h}-{w}"]

    def __call__(self, tokens, rope_positions, ch_split, parallel=False):
        """
        input:
            * tokens: batch_size x ntokens x nheads x dim
            * rope_positions: list of (f, h, w)
        output:
            * tokens after applying RoPE2D (batch_size x ntokens x nheads x dim)
        """
        assert sum(ch_split) == tokens.size(-1)

        mesh_grid = self.get_mesh_3d(rope_positions, bsz=tokens.shape[0])
        out = []
        for i, (D, x) in enumerate(zip(ch_split, torch.split(tokens, ch_split, dim=-1))):
            cos, sin = self.get_cos_sin(D, int(mesh_grid.max()) + 1, tokens.device, tokens.dtype)
            if parallel:
                mesh = torch.chunk(mesh_grid[:, :, i], nccl_info.sp_size, dim=1)[nccl_info.rank_within_group].clone()
            else:
                mesh = mesh_grid[:, :, i].clone()
            x = self.apply_rope1d(x, mesh.to(tokens.device), cos, sin)
            out.append(x)

        tokens = torch.cat(out, dim=-1)
        return tokens
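A minimal usage sketch of RoPE3D under assumed shapes, run inside the repo so the nccl_info import resolves; ch_split must sum to the per-head dimension, mirroring rope_dim_list-style splits over (t, h, w).

import torch

rope = RoPE3D()
f, h, w, heads, dim = 2, 4, 4, 8, 128                            # assumed sizes
tokens = torch.randn(1, f * h * w, heads, dim)                   # (bsz, ntokens, nheads, dim)
out = rope(tokens, rope_positions=(f, h, w), ch_split=[16, 56, 56])
print(out.shape)                                                 # same shape as the input tokens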
FastVideo-main/fastvideo/models/stepvideo/parallel.py
0 → 100644
View file @
c07946d8
import torch

from fastvideo.utils.communications import all_gather
from fastvideo.utils.parallel_states import nccl_info


def parallel_forward(fn_):

    def wrapTheFunction(_, hidden_states, *args, **kwargs):
        if kwargs['parallel']:
            hidden_states = torch.chunk(hidden_states, nccl_info.sp_size, dim=-2)[nccl_info.rank_within_group]
            kwargs['attn_mask'] = torch.chunk(kwargs['attn_mask'], nccl_info.sp_size,
                                              dim=-2)[nccl_info.rank_within_group]

        output = fn_(_, hidden_states, *args, **kwargs)

        if kwargs['parallel']:
            output = all_gather(output.contiguous(), dim=-2)

        return output

    return wrapTheFunction
FastVideo-main/fastvideo/models/stepvideo/text_encoder/__init__.py
0 → 100644
View file @
c07946d8
import os

import torch

from fastvideo.models.stepvideo.config import parse_args

try:
    args = parse_args()
    torch.ops.load_library(
        os.path.join(args.model_dir, 'lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so'))
except Exception as err:
    print(err)
FastVideo-main/fastvideo/models/stepvideo/text_encoder/clip.py
0 → 100644
View file @
c07946d8
import os

import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer


class HunyuanClip(nn.Module):
    """
    Hunyuan clip code copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py

    hunyuan's clip used BertModel and BertTokenizer, so we copy it.
    """

    def __init__(self, model_dir, max_length=77):
        super(HunyuanClip, self).__init__()

        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained(os.path.join(model_dir, 'tokenizer'))
        self.text_encoder = BertModel.from_pretrained(os.path.join(model_dir, 'clip_text_encoder'))

    @torch.no_grad
    def forward(self, prompts, with_mask=True):
        self.device = next(self.text_encoder.parameters()).device
        text_inputs = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        prompt_embeds = self.text_encoder(
            text_inputs.input_ids.to(self.device),
            attention_mask=text_inputs.attention_mask.to(self.device) if with_mask else None,
        )
        return prompt_embeds.last_hidden_state, prompt_embeds.pooler_output
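A minimal usage sketch for HunyuanClip; the model_dir path below is an assumption and must contain the 'tokenizer' and 'clip_text_encoder' subfolders referenced above.

clip = HunyuanClip("path/to/hunyuan_clip", max_length=77)        # assumed local path
seq_feats, pooled = clip(["a corgi surfing a wave at sunset"])
# seq_feats: (1, 77, hidden) per-token states; pooled: (1, hidden) pooled embedding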
FastVideo-main/fastvideo/models/stepvideo/text_encoder/flashattention.py
0 → 100644
View file @
c07946d8
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import torch


def flash_attn_func(q,
                    k,
                    v,
                    dropout_p=0.0,
                    softmax_scale=None,
                    causal=True,
                    return_attn_probs=False,
                    tp_group_rank=0,
                    tp_group_size=1):
    softmax_scale = q.size(-1)**(-0.5) if softmax_scale is None else softmax_scale
    return torch.ops.Optimus.fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None,
                                 tp_group_rank, tp_group_size)[0]


class FlashSelfAttention(torch.nn.Module):

    def __init__(
        self,
        attention_dropout=0.0,
    ):
        super().__init__()
        self.dropout_p = attention_dropout

    def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
        if cu_seqlens is None:
            output = flash_attn_func(q, k, v, dropout_p=self.dropout_p)
        else:
            raise ValueError('cu_seqlens is not supported!')

        return output
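A minimal usage sketch for FlashSelfAttention; it dispatches to the custom torch.ops.Optimus.fwd kernel loaded in text_encoder/__init__.py, so a CUDA device and the bundled liboptimus library are assumed, and the shapes below are illustrative.

import torch

attn = FlashSelfAttention(attention_dropout=0.0)
q = torch.randn(2, 320, 8, 64, device="cuda", dtype=torch.bfloat16)   # (b, s, h, d), assumed sizes
out = attn(q, q, q)            # causal self-attention over the sequence dimension
print(out.shape)               # (2, 320, 8, 64)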
FastVideo-main/fastvideo/models/stepvideo/text_encoder/stepllm.py
0 → 100644
View file @
c07946d8
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import os
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel

from fastvideo.models.stepvideo.modules.normalization import RMSNorm
from fastvideo.models.stepvideo.text_encoder.flashattention import FlashSelfAttention
from fastvideo.models.stepvideo.text_encoder.tokenizer import LLaMaEmbedding, Wrapped_StepChatTokenizer
from fastvideo.models.stepvideo.utils import with_empty_init


def safediv(n, d):
    q, r = divmod(n, d)
    assert r == 0
    return q


class MultiQueryAttention(nn.Module):

    def __init__(self, cfg, layer_id=None):
        super().__init__()
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.max_seq_len = cfg.seq_length
        self.use_flash_attention = cfg.use_flash_attn
        assert self.use_flash_attention, 'FlashAttention is required!'

        self.n_groups = cfg.num_attention_groups
        self.tp_size = 1
        self.n_local_heads = cfg.num_attention_heads
        self.n_local_groups = self.n_groups

        self.wqkv = nn.Linear(
            cfg.hidden_size,
            cfg.hidden_size + self.head_dim * 2 * self.n_groups,
            bias=False,
        )
        self.wo = nn.Linear(
            cfg.hidden_size,
            cfg.hidden_size,
            bias=False,
        )

        assert self.use_flash_attention, 'non-Flash attention not supported yet.'
        self.core_attention = FlashSelfAttention(attention_dropout=cfg.attention_dropout)
        self.layer_id = layer_id

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        cu_seqlens: Optional[torch.Tensor],
        max_seq_len: Optional[torch.Tensor],
    ):
        seqlen, bsz, dim = x.shape

        xqkv = self.wqkv(x)
        xq, xkv = torch.split(
            xqkv,
            (dim // self.tp_size, self.head_dim * 2 * self.n_groups // self.tp_size),
            dim=-1,
        )

        # gather on 1st dimension
        xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
        xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
        xk, xv = xkv.chunk(2, -1)

        # rotary embedding + flash attn
        xq = rearrange(xq, "s b h d -> b s h d")
        xk = rearrange(xk, "s b h d -> b s h d")
        xv = rearrange(xv, "s b h d -> b s h d")

        q_per_kv = self.n_local_heads // self.n_local_groups
        if q_per_kv > 1:
            b, s, h, d = xk.size()
            if h == 1:
                xk = xk.expand(b, s, q_per_kv, d)
                xv = xv.expand(b, s, q_per_kv, d)
            else:
                ''' To cover the cases where h > 1, we have
                    the following implementation, which is equivalent to:
                        xk = xk.repeat_interleave(q_per_kv, dim=-2)
                        xv = xv.repeat_interleave(q_per_kv, dim=-2)
                    but can avoid calling aten::item() that involves cpu.
                '''
                idx = torch.arange(q_per_kv * h, device=xk.device).reshape(q_per_kv, -1).permute(1, 0).flatten()
                xk = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
                xv = torch.index_select(xv.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()

        if self.use_flash_attention:
            output = self.core_attention(xq, xk, xv, cu_seqlens=cu_seqlens, max_seq_len=max_seq_len)
            # reduce-scatter only support first dimension now
            output = rearrange(output, "b s h d -> s b (h d)").contiguous()
        else:
            xq, xk, xv = [rearrange(x, "b s ... -> s b ...").contiguous() for x in (xq, xk, xv)]
            output = self.core_attention(xq, xk, xv, mask)

        output = self.wo(output)
        return output


class FeedForward(nn.Module):

    def __init__(
        self,
        cfg,
        dim: int,
        hidden_dim: int,
        layer_id: int,
        multiple_of: int = 256,
    ):
        super().__init__()
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.swiglu = swiglu
        self.w1 = nn.Linear(
            dim,
            2 * hidden_dim,
            bias=False,
        )
        self.w2 = nn.Linear(
            hidden_dim,
            dim,
            bias=False,
        )

    def forward(self, x):
        x = self.swiglu(self.w1(x))
        output = self.w2(x)
        return output


class TransformerBlock(nn.Module):

    def __init__(self, cfg, layer_id: int):
        super().__init__()
        self.n_heads = cfg.num_attention_heads
        self.dim = cfg.hidden_size
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.attention = MultiQueryAttention(
            cfg,
            layer_id=layer_id,
        )
        self.feed_forward = FeedForward(
            cfg,
            dim=cfg.hidden_size,
            hidden_dim=cfg.ffn_hidden_size,
            layer_id=layer_id,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(
            cfg.hidden_size,
            eps=cfg.layernorm_epsilon,
        )
        self.ffn_norm = RMSNorm(
            cfg.hidden_size,
            eps=cfg.layernorm_epsilon,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        cu_seqlens: Optional[torch.Tensor],
        max_seq_len: Optional[torch.Tensor],
    ):
        residual = self.attention.forward(self.attention_norm(x), mask, cu_seqlens, max_seq_len)
        h = x + residual
        ffn_res = self.feed_forward.forward(self.ffn_norm(h))
        out = h + ffn_res
        return out


class Transformer(nn.Module):

    def __init__(
        self,
        config,
        max_seq_size=8192,
    ):
        super().__init__()
        self.num_layers = config.num_layers
        self.layers = self._build_layers(config)

    def _build_layers(self, config):
        layers = torch.nn.ModuleList()
        for layer_id in range(self.num_layers):
            layers.append(TransformerBlock(
                config,
                layer_id=layer_id + 1,
            ))
        return layers

    def forward(
        self,
        hidden_states,
        attention_mask,
        cu_seqlens=None,
        max_seq_len=None,
    ):
        if max_seq_len is not None and not isinstance(max_seq_len, torch.Tensor):
            max_seq_len = torch.tensor(max_seq_len, dtype=torch.int32, device="cpu")

        for lid, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states,
                attention_mask,
                cu_seqlens,
                max_seq_len,
            )

        return hidden_states


class Step1Model(PreTrainedModel):
    config_class = PretrainedConfig

    @with_empty_init
    def __init__(
        self,
        config,
    ):
        super().__init__(config)
        self.tok_embeddings = LLaMaEmbedding(config)
        self.transformer = Transformer(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
    ):
        hidden_states = self.tok_embeddings(input_ids)
        hidden_states = self.transformer(
            hidden_states,
            attention_mask,
        )
        return hidden_states


class STEP1TextEncoder(torch.nn.Module):

    def __init__(self, model_dir, max_length=320):
        super(STEP1TextEncoder, self).__init__()
        self.max_length = max_length
        self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
        text_encoder = Step1Model.from_pretrained(model_dir)
        self.text_encoder = text_encoder.eval().to(torch.bfloat16)

    @torch.no_grad
    def forward(self, prompts, with_mask=True, max_length=None):
        self.device = next(self.text_encoder.parameters()).device
        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
            if type(prompts) is str:
                prompts = [prompts]

            txt_tokens = self.text_tokenizer(
                prompts,
                max_length=max_length or self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt")
            y = self.text_encoder(
                txt_tokens.input_ids.to(self.device),
                attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None)
            y_mask = txt_tokens.attention_mask
        return y.transpose(0, 1), y_mask
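A minimal usage sketch for STEP1TextEncoder; model_dir is an assumed path holding the pretrained Step1 weights plus step1_chat_tokenizer.model, and a CUDA device is assumed because the attention path relies on the Optimus kernel.

encoder = STEP1TextEncoder("path/to/step_llm", max_length=320).cuda()    # assumed path
y, y_mask = encoder(["a corgi surfing a wave at sunset"])
# y: (1, 320, hidden) caption features (batch-first after the transpose); y_mask: (1, 320) token mask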
FastVideo-main/fastvideo/models/stepvideo/text_encoder/tokenizer.py
0 → 100644
View file @
c07946d8
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
from typing import List

import torch
import torch.nn as nn


class LLaMaEmbedding(nn.Module):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
            is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
            will ignore this embedding
    """

    def __init__(
        self,
        cfg,
    ):
        super().__init__()
        self.hidden_size = cfg.hidden_size
        self.params_dtype = cfg.params_dtype
        self.fp32_residual_connection = cfg.fp32_residual_connection
        self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32
        self.word_embeddings = torch.nn.Embedding(
            cfg.padded_vocab_size,
            self.hidden_size,
        )
        self.embedding_dropout = torch.nn.Dropout(cfg.hidden_dropout)

    def forward(self, input_ids):
        # Embeddings.
        if self.embedding_weights_in_fp32:
            self.word_embeddings = self.word_embeddings.to(torch.float32)
        embeddings = self.word_embeddings(input_ids)
        if self.embedding_weights_in_fp32:
            embeddings = embeddings.to(self.params_dtype)
            self.word_embeddings = self.word_embeddings.to(self.params_dtype)

        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        # If the input flag for fp32 residual connection is set, convert for float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)
        return embeddings
class StepChatTokenizer:
    """Step Chat Tokenizer"""

    def __init__(
        self,
        model_file,
        name="StepChatTokenizer",
        bot_token="<|BOT|>",  # Begin of Turn
        eot_token="<|EOT|>",  # End of Turn
        call_start_token="<|CALL_START|>",  # Call Start
        call_end_token="<|CALL_END|>",  # Call End
        think_start_token="<|THINK_START|>",  # Think Start
        think_end_token="<|THINK_END|>",  # Think End
        mask_start_token="<|MASK_1e69f|>",  # Mask start
        mask_end_token="<|UNMASK_1e69f|>",  # Mask end
    ):
        import sentencepiece

        self._tokenizer = sentencepiece.SentencePieceProcessor(
            model_file=model_file)

        self._vocab = {}
        self._inv_vocab = {}
        self._special_tokens = {}
        self._inv_special_tokens = {}
        self._t5_tokens = []

        for idx in range(self._tokenizer.get_piece_size()):
            text = self._tokenizer.id_to_piece(idx)
            self._inv_vocab[idx] = text
            self._vocab[text] = idx
            if self._tokenizer.is_control(idx) or self._tokenizer.is_unknown(idx):
                self._special_tokens[text] = idx
                self._inv_special_tokens[idx] = text

        self._unk_id = self._tokenizer.unk_id()
        self._bos_id = self._tokenizer.bos_id()
        self._eos_id = self._tokenizer.eos_id()

        for token in [
                bot_token, eot_token, call_start_token, call_end_token,
                think_start_token, think_end_token
        ]:
            assert token in self._vocab, f"Token '{token}' not found in tokenizer"
            assert token in self._special_tokens, f"Token '{token}' is not a special token"

        for token in [mask_start_token, mask_end_token]:
            assert token in self._vocab, f"Token '{token}' not found in tokenizer"

        self._bot_id = self._tokenizer.piece_to_id(bot_token)
        self._eot_id = self._tokenizer.piece_to_id(eot_token)
        self._call_start_id = self._tokenizer.piece_to_id(call_start_token)
        self._call_end_id = self._tokenizer.piece_to_id(call_end_token)
        self._think_start_id = self._tokenizer.piece_to_id(think_start_token)
        self._think_end_id = self._tokenizer.piece_to_id(think_end_token)
        self._mask_start_id = self._tokenizer.piece_to_id(mask_start_token)
        self._mask_end_id = self._tokenizer.piece_to_id(mask_end_token)
        self._underline_id = self._tokenizer.piece_to_id("\u2581")

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return self._inv_vocab

    @property
    def vocab_size(self):
        return self._tokenizer.vocab_size()

    def tokenize(self, text: str) -> List[int]:
        return self._tokenizer.encode_as_ids(text)

    def detokenize(self, token_ids: List[int]) -> str:
        return self._tokenizer.decode_ids(token_ids)
class Tokens:

    def __init__(self, input_ids, cu_input_ids, attention_mask, cu_seqlens,
                 max_seq_len) -> None:
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.cu_input_ids = cu_input_ids
        self.cu_seqlens = cu_seqlens
        self.max_seq_len = max_seq_len

    def to(self, device):
        self.input_ids = self.input_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        self.cu_input_ids = self.cu_input_ids.to(device)
        self.cu_seqlens = self.cu_seqlens.to(device)
        return self
class Wrapped_StepChatTokenizer(StepChatTokenizer):

    def __call__(self,
                 text,
                 max_length=320,
                 padding="max_length",
                 truncation=True,
                 return_tensors="pt"):
        # [bos, ..., eos, pad, pad, ..., pad]
        self.BOS = 1
        self.EOS = 2
        self.PAD = 2
        out_tokens = []
        attn_mask = []
        if len(text) == 0:
            part_tokens = [self.BOS] + [self.EOS]
            valid_size = len(part_tokens)
            if len(part_tokens) < max_length:
                part_tokens += [self.PAD] * (max_length - valid_size)
            out_tokens.append(part_tokens)
            attn_mask.append([1] * valid_size + [0] * (max_length - valid_size))
        else:
            for part in text:
                part_tokens = self.tokenize(part)
                part_tokens = part_tokens[:(max_length - 2)]  # leave 2 space for bos and eos
                part_tokens = [self.BOS] + part_tokens + [self.EOS]
                valid_size = len(part_tokens)
                if len(part_tokens) < max_length:
                    part_tokens += [self.PAD] * (max_length - valid_size)
                out_tokens.append(part_tokens)
                attn_mask.append([1] * valid_size + [0] * (max_length - valid_size))

        out_tokens = torch.tensor(out_tokens, dtype=torch.long)
        attn_mask = torch.tensor(attn_mask, dtype=torch.long)

        # padding y based on tp size
        padded_len = 0
        padded_flag = True if padded_len > 0 else False
        if padded_flag:
            pad_tokens = torch.tensor([[self.PAD] * max_length],
                                      device=out_tokens.device)
            pad_attn_mask = torch.tensor(
                [[1] * padded_len + [0] * (max_length - padded_len)],
                device=attn_mask.device)
            out_tokens = torch.cat([out_tokens, pad_tokens], dim=0)
            attn_mask = torch.cat([attn_mask, pad_attn_mask], dim=0)

        # cu_seqlens
        cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0)
        seqlen = attn_mask.sum(dim=1).tolist()
        cu_seqlens = torch.cumsum(torch.tensor([0] + seqlen), 0).to(
            device=out_tokens.device, dtype=torch.int32)
        max_seq_len = max(seqlen)
        return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens,
                      max_seq_len)
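A small sketch of the calling convention, assuming the same SentencePiece model file loaded by STEP1TextEncoder above; the returned Tokens bundle carries both the padded batch and the cumulative sequence lengths that the transformer's varlen attention path consumes. Prompts and shapes are illustrative:

# Hypothetical usage of Wrapped_StepChatTokenizer.
tokenizer = Wrapped_StepChatTokenizer("step1_chat_tokenizer.model")
tokens = tokenizer(["a cat playing piano", "rain on a window"], max_length=320)
print(tokens.input_ids.shape)   # (2, 320): BOS + ids + EOS, padded with PAD=2
print(tokens.cu_seqlens)        # e.g. tensor([0, n1, n1 + n2], dtype=int32)
print(tokens.max_seq_len)       # longest unpadded sequence in the batch
tokens = tokens.to("cuda")      # Tokens.to moves every tensor field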
FastVideo-main/fastvideo/models/stepvideo/utils/__init__.py
0 → 100644
View file @
c07946d8
from .utils import *
from .video_process import *
\ No newline at end of file
FastVideo-main/fastvideo/models/stepvideo/utils/quantization.py
0 → 100644
View file @
c07946d8
# from stepvideo.diffusion.video_pipeline import StepVideoPipeline
import torch
import torch.nn as nn
from torch.nn import functional as F


def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
    _bits = torch.tensor(bits)
    _mantissa_bit = torch.tensor(mantissa_bit)
    _sign_bits = torch.tensor(sign_bits)
    M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
    E = _bits - _sign_bits - M
    bias = 2**(E - 1) - 1
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2**(i + 1))
    maxval = mantissa * 2**(2**E - 1 - bias)
    return maxval
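Working through the defaults as a sanity check (E4M3: 8 bits, 3 mantissa bits, 1 sign bit): M = 3, E = 4, bias = 2**3 - 1 = 7, mantissa = 1 + 1/2 + 1/4 = 1.75, so maxval = 1.75 * 2**(2**4 - 1 - 7) = 1.75 * 256 = 448, the usual E4M3 dynamic-range limit. The values below are computed by hand from the formula above, not repository output:

# E4M3 defaults: M = 3, E = 4, bias = 7, mantissa = 1.75 -> 1.75 * 2**8 = 448.
print(float(get_fp_maxval()))  # expected: 448.0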
def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
    """
    Default is E4M3.
    """
    bits = torch.tensor(bits)
    mantissa_bit = torch.tensor(mantissa_bit)
    sign_bits = torch.tensor(sign_bits)
    M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
    E = bits - sign_bits - M
    bias = 2**(E - 1) - 1
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2**(i + 1))
    maxval = mantissa * 2**(2**E - 1 - bias)
    minval = -maxval
    minval = -maxval if sign_bits == 1 else torch.zeros_like(maxval)
    input_clamp = torch.min(torch.max(x, minval), maxval)
    log_scales = torch.clamp(
        (torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
    log_scales = 2.0**(log_scales - M - bias.type(x.dtype))
    # dequant
    qdq_out = torch.round(input_clamp / log_scales) * log_scales
    return qdq_out, log_scales
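A minimal sketch of the fake-quantization round trip this function performs, assuming a small random weight tensor; the numbers are illustrative only:

# Hypothetical round trip: values are clamped to [-448, 448] and snapped to
# the E4M3 grid, so the reconstruction error stays within one grid step.
w = torch.randn(4, 4) * 10
w_qdq, scales = quantize_to_fp8(w)
print((w - w_qdq).abs().max())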
def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
    for i in range(len(x.shape) - 1):
        scale = scale.unsqueeze(-1)
    new_x = x / scale
    quant_dequant_x, log_scales = quantize_to_fp8(new_x,
                                                  bits=bits,
                                                  mantissa_bit=mantissa_bit,
                                                  sign_bits=sign_bits)
    return quant_dequant_x, scale, log_scales


def fp8_activation_dequant(qdq_out, scale, dtype):
    qdq_out = qdq_out.type(dtype)
    quant_dequant_x = qdq_out * scale.to(dtype)
    return quant_dequant_x


def fp8_linear_forward(cls, original_dtype, input):
    weight_dtype = cls.weight.dtype
    #####
    if cls.weight.dtype != torch.float8_e4m3fn:
        assert False
        maxval = get_fp_maxval()
        scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
        linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
        linear_weight = linear_weight.to(torch.float8_e4m3fn)
        weight_dtype = linear_weight.dtype
    else:
        scale = cls.fp8_scale.to(cls.weight.device)
        linear_weight = cls.weight
    #####

    if weight_dtype == torch.float8_e4m3fn:
        if True or len(input.shape) == 3:
            cls_dequant = fp8_activation_dequant(linear_weight, scale,
                                                 original_dtype)
            if cls.bias is not None:
                print(f"input dtype: {input.dtype}")
                print(f"cls_dequant dtype: {cls_dequant.dtype}")
                print(f"cls.bias dtype: {cls.bias.dtype}")
                output = F.linear(input, cls_dequant, cls.bias)
            else:
                output = F.linear(input, cls_dequant)
            return output
        else:
            return cls.original_forward(input.to(original_dtype))
    else:
        return cls.original_forward(input)


def convert_fp8_linear(module, original_dtype, params_to_keep={}):
    setattr(module, "fp8_matmul_enabled", True)
    fp8_layers = []
    scale_dict = {}
    counter = 0
    for key, layer in module.named_modules():
        if isinstance(layer, nn.Linear) and 'transformer_blocks' in key:
            print(f"Converting {key} to FP8")
            fp8_layers.append(key)
            original_forward = layer.forward
            maxval = get_fp_maxval()
            scale = torch.max(torch.abs(layer.weight.flatten())) / maxval
            original_weight = layer.weight.data  # Store a reference to the original weights
            quantized_weight, scale, _ = fp8_tensor_quant(original_weight, scale)
            scale_dict[key] = scale
            layer.weight = torch.nn.Parameter(
                quantized_weight.to(torch.float8_e4m3fn))
            del original_weight  # Delete the reference to the original weights
            torch.cuda.empty_cache()
            # print(f"layer weight dtype: {layer.weight.dtype} for layer {key}")
            setattr(layer, "fp8_scale", scale.to(dtype=original_dtype))
            setattr(layer, "original_forward", original_forward)
            setattr(
                layer, "forward",
                lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
            counter += 1
    return scale_dict
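A rough sketch of how this conversion might be applied before inference. Only nn.Linear layers whose names contain 'transformer_blocks' are patched, per the filter above; the loader call is a placeholder, not an API from this repository:

# Hypothetical usage: store transformer-block Linear weights in FP8 while the
# patched forward dequantizes back to bf16 for the matmul.
transformer = load_stepvideo_transformer()  # placeholder loader, not a repo API
scales = convert_fp8_linear(transformer, original_dtype=torch.bfloat16)
print(f"patched {len(scales)} Linear layers")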
FastVideo-main/fastvideo/models/stepvideo/utils/utils.py
0 → 100644
View file @
c07946d8
import random
from functools import wraps

import numpy as np
import torch
import torch.utils._device


def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):

    def __init__(self, device=None):
        self.device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if getattr(func, '__module__', None) == 'torch.nn.init':
            if 'tensor' in kwargs:
                return kwargs['tensor']
            else:
                return args[0]
        if self.device is not None and func in torch.utils._device._device_constructors(
        ) and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)


def with_empty_init(func):

    @wraps(func)
    def wrapper(*args, **kwargs):
        with EmptyInitOnDevice('cpu'):
            return func(*args, **kwargs)

    return wrapper
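A short sketch of what this decorator buys, assuming a large module whose weights are overwritten by from_pretrained anyway, so random initialization is wasted work:

# Hypothetical example: inside the context, torch.nn.init calls become no-ops
# and tensor constructors default to CPU, so construction skips the fill work.
with EmptyInitOnDevice('cpu'):
    layer = torch.nn.Linear(4096, 4096)  # weights allocated but not initialized
# Step1Model.__init__ above gets the same behaviour via @with_empty_init.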
def culens2mask(cu_seqlens=None,
                cu_seqlens_kv=None,
                max_seqlen=None,
                max_seqlen_kv=None,
                is_causal=False):
    assert len(cu_seqlens) == len(cu_seqlens_kv), "q k v should have same bsz..."
    bsz = len(cu_seqlens) - 1
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    attn_mask = torch.zeros(bsz, max_seqlen, max_seqlen_kv, dtype=torch.bool)
    for i, (seq_len, seq_len_kv) in enumerate(zip(seqlens, seqlens_kv)):
        if is_causal:
            attn_mask[i, :seq_len, :seq_len_kv] = torch.triu(
                torch.ones(seq_len, seq_len_kv), diagonal=1).bool()
        else:
            attn_mask[i, :seq_len, :seq_len_kv] = torch.ones(
                [seq_len, seq_len_kv], dtype=torch.bool)
    return attn_mask
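A small worked example of the mask layout, assuming two variable-length sequences packed into cumulative offsets; the lengths are made up for illustration:

# Hypothetical input: batch of 2 with lengths 3 and 5 for both q and kv.
cu = torch.tensor([0, 3, 8], dtype=torch.int32)
mask = culens2mask(cu_seqlens=cu, cu_seqlens_kv=cu, max_seqlen=5, max_seqlen_kv=5)
print(mask.shape)     # torch.Size([2, 5, 5])
print(mask[0].int())  # only the top-left 3x3 block is 1 in the non-causal case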