TS-MODELS-OPT / training / Video-Generation-Model · Commits

Commit c07946d8, authored Apr 09, 2026 by hepj

    dit & video

Changes: 270 files. Showing 20 changed files with 6884 additions and 0 deletions (+6884, -0).
Changed files shown in this view:

FastVideo-main/fastvideo/models/stepvideo/utils/video_process.py   +51   -0
FastVideo-main/fastvideo/models/stepvideo/vae/vae.py                +975  -0
FastVideo-main/fastvideo/sample/call_remote_server_stepvideo.py     +179  -0
FastVideo-main/fastvideo/sample/generate_synthetic.py               +118  -0
FastVideo-main/fastvideo/sample/sample_t2v_hunyuan.py               +210  -0
FastVideo-main/fastvideo/sample/sample_t2v_hunyuan_STA.py           +407  -0
FastVideo-main/fastvideo/sample/sample_t2v_hunyuan_hf.py            +311  -0
FastVideo-main/fastvideo/sample/sample_t2v_mochi.py                 +151  -0
FastVideo-main/fastvideo/sample/sample_t2v_mochi_no_sp.py           +55   -0
FastVideo-main/fastvideo/sample/sample_t2v_stepvideo.py             +246  -0
FastVideo-main/fastvideo/sample/sample_t2v_stepvideo_STA.py         +374  -0
FastVideo-main/fastvideo/train.py                                   +695  -0
FastVideo-main/fastvideo/train.py-bak                               +664  -0
FastVideo-main/fastvideo/train_back.py                              +670  -0
FastVideo-main/fastvideo/train_prof.py                              +673  -0
FastVideo-main/fastvideo/utils/checkpoint.py                        +286  -0
FastVideo-main/fastvideo/utils/communications.py                    +307  -0
FastVideo-main/fastvideo/utils/dataset_utils.py                     +342  -0
FastVideo-main/fastvideo/utils/env_utils.py                         +38   -0
FastVideo-main/fastvideo/utils/fsdp_util.py                         +132  -0
Too many changes to show: only 270 of 270+ changed files are displayed in this view.
FastVideo-main/fastvideo/models/stepvideo/utils/video_process.py  (new file, mode 100644)
import os

import imageio
import numpy as np
import torch


class VideoProcessor:

    def __init__(self, save_path: str = './results', name_suffix: str = ''):
        self.save_path = save_path
        os.makedirs(self.save_path, exist_ok=True)
        self.name_suffix = name_suffix

    def crop2standard540p(self, vid_array):
        _, height, width, _ = vid_array.shape
        height_center = height // 2
        width_center = width // 2
        if width_center > height_center:
            ## horizontal mode
            return vid_array[:, height_center - 270:height_center + 270,
                             width_center - 480:width_center + 480]
        elif width_center < height_center:
            ## portrait mode
            return vid_array[:, height_center - 480:height_center + 480,
                             width_center - 270:width_center + 270]
        else:
            return vid_array

    def save_imageio_video(self, video_array: np.array, output_filename: str, fps=25, codec='libx264'):
        ffmpeg_params = [
            "-vf",
            "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1",  # denoise
        ]
        with imageio.get_writer(output_filename, fps=fps, codec=codec,
                                ffmpeg_params=ffmpeg_params) as vid_writer:
            for img_array in video_array:
                vid_writer.append_data(img_array)

    def postprocess_video(self, video_tensor, output_file_name='', output_type="mp4", crop2standard540p=True):
        if len(self.name_suffix) == 0:
            video_path = os.path.join(self.save_path, f"{output_file_name}.{output_type}")
        else:
            video_path = os.path.join(self.save_path, f"{output_file_name}-{self.name_suffix}.{output_type}")

        video_tensor = torch.cat([t for t in video_tensor], dim=-2)
        video_tensor = (video_tensor.cpu().clamp(-1, 1) + 1) * 127.5
        video_array = video_tensor.clamp(0, 255).to(torch.uint8).numpy().transpose(0, 2, 3, 1)
        if crop2standard540p:
            video_array = self.crop2standard540p(video_array)
        self.save_imageio_video(video_array, video_path)
        print(f"Saved the generated video in {video_path}")
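A minimal usage sketch for the VideoProcessor above. The tensor shape, save path and suffix are illustrative assumptions (a random tensor stands in for real decoder output in [-1, 1] with channels-first frames), not values taken from the repository.

# Hypothetical example: save a fake 16-frame clip with VideoProcessor.
import torch
from fastvideo.models.stepvideo.utils.video_process import VideoProcessor

processor = VideoProcessor(save_path="./results", name_suffix="demo")
fake_video = torch.rand(1, 16, 3, 544, 992) * 2 - 1          # values in [-1, 1], (B, T, C, H, W)
processor.postprocess_video(fake_video, output_file_name="sample", crop2standard540p=True)
# -> ./results/sample-demo.mp4, center-cropped to 960x540 (horizontal mode)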
FastVideo-main/fastvideo/models/stepvideo/vae/vae.py  (new file, mode 100644)
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F

from fastvideo.models.stepvideo.utils import with_empty_init


def base_group_norm(x, norm_layer, act_silu=False, channel_last=False):
    if hasattr(base_group_norm, 'spatial') and base_group_norm.spatial:
        assert channel_last
        x_shape = x.shape
        x = x.flatten(0, 1)
        if channel_last:
            # Permute to NCHW format
            x = x.permute(0, 3, 1, 2)
        out = F.group_norm(x.contiguous(), norm_layer.num_groups, norm_layer.weight, norm_layer.bias,
                           norm_layer.eps)
        if act_silu:
            out = F.silu(out)
        if channel_last:
            # Permute back to NHWC format
            out = out.permute(0, 2, 3, 1)
        out = out.view(x_shape)
    else:
        if channel_last:
            # Permute to NCHW format
            x = x.permute(0, 3, 1, 2)
        out = F.group_norm(x.contiguous(), norm_layer.num_groups, norm_layer.weight, norm_layer.bias,
                           norm_layer.eps)
        if act_silu:
            out = F.silu(out)
        if channel_last:
            # Permute back to NHWC format
            out = out.permute(0, 2, 3, 1)
    return out


def base_conv2d(x, conv_layer, channel_last=False, residual=None):
    if channel_last:
        x = x.permute(0, 3, 1, 2)  # NHWC to NCHW
    out = F.conv2d(x, conv_layer.weight, conv_layer.bias, stride=conv_layer.stride, padding=conv_layer.padding)
    if residual is not None:
        if channel_last:
            residual = residual.permute(0, 3, 1, 2)  # NHWC to NCHW
        out += residual
    if channel_last:
        out = out.permute(0, 2, 3, 1)  # NCHW to NHWC
    return out


def base_conv3d(x, conv_layer, channel_last=False, residual=None, only_return_output=False):
    if only_return_output:
        size = cal_outsize(x.shape, conv_layer.weight.shape, conv_layer.stride, conv_layer.padding)
        return torch.empty(size, device=x.device, dtype=x.dtype)
    if channel_last:
        x = x.permute(0, 4, 1, 2, 3)  # NDHWC to NCDHW
    out = F.conv3d(x, conv_layer.weight, conv_layer.bias, stride=conv_layer.stride, padding=conv_layer.padding)
    if residual is not None:
        if channel_last:
            residual = residual.permute(0, 4, 1, 2, 3)  # NDHWC to NCDHW
        out += residual
    if channel_last:
        out = out.permute(0, 2, 3, 4, 1)  # NCDHW to NDHWC
    return out


def cal_outsize(input_sizes, kernel_sizes, stride, padding):
    stride_d, stride_h, stride_w = stride
    padding_d, padding_h, padding_w = padding
    dilation_d, dilation_h, dilation_w = 1, 1, 1

    in_d = input_sizes[1]
    in_h = input_sizes[2]
    in_w = input_sizes[3]

    kernel_d = kernel_sizes[2]
    kernel_h = kernel_sizes[3]
    kernel_w = kernel_sizes[4]

    out_channels = kernel_sizes[0]

    out_d = calc_out_(in_d, padding_d, dilation_d, kernel_d, stride_d)
    out_h = calc_out_(in_h, padding_h, dilation_h, kernel_h, stride_h)
    out_w = calc_out_(in_w, padding_w, dilation_w, kernel_w, stride_w)
    size = [input_sizes[0], out_d, out_h, out_w, out_channels]
    return size


def calc_out_(in_size, padding, dilation, kernel, stride):
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
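A quick numeric check of the output-size arithmetic used by cal_outsize/calc_out_ above; the values are illustrative, chosen only to show that the formula matches the usual convolution output-size rule.

# Illustrative check of the formula in calc_out_ (kernel 3, stride 2, padding 1, dilation 1).
in_size, padding, dilation, kernel, stride = 17, 1, 1, 3, 2
out = (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
assert out == 9  # same result a torch.nn.Conv3d with these hyperparameters would produce on that axis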
def base_conv3d_channel_last(x, conv_layer, residual=None):
    in_numel = x.numel()
    out_numel = int(x.numel() * conv_layer.out_channels / conv_layer.in_channels)
    if (in_numel >= 2**30) or (out_numel >= 2**30):
        assert conv_layer.stride[0] == 1, "time split asks time stride = 1"

        B, T, H, W, C = x.shape
        K = conv_layer.kernel_size[0]
        chunks = 4
        chunk_size = T // chunks

        if residual is None:
            out_nhwc = base_conv3d(x, conv_layer, channel_last=True, residual=residual, only_return_output=True)
        else:
            out_nhwc = residual

        assert B == 1
        for i in range(chunks):
            if i == chunks - 1:
                xi = x[:1, chunk_size * i:]
                out_nhwci = out_nhwc[:1, chunk_size * i:]
            else:
                xi = x[:1, chunk_size * i:chunk_size * (i + 1) + K - 1]
                out_nhwci = out_nhwc[:1, chunk_size * i:chunk_size * (i + 1)]
            if residual is not None:
                if i == chunks - 1:
                    ri = residual[:1, chunk_size * i:]
                else:
                    ri = residual[:1, chunk_size * i:chunk_size * (i + 1)]
            else:
                ri = None
            out_nhwci.copy_(base_conv3d(xi, conv_layer, channel_last=True, residual=ri))
    else:
        out_nhwc = base_conv3d(x, conv_layer, channel_last=True, residual=residual)
    return out_nhwc


class Upsample2D(nn.Module):

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose

        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
        else:
            assert "Not Supported"
            self.conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)

    def forward(self, x, output_size=None):
        assert x.shape[-1] == self.channels

        if self.use_conv_transpose:
            return self.conv(x)

        if output_size is None:
            x = F.interpolate(x.permute(0, 3, 1, 2).to(memory_format=torch.channels_last),
                              scale_factor=2.0, mode='nearest').permute(0, 2, 3, 1).contiguous()
        else:
            x = F.interpolate(x.permute(0, 3, 1, 2).to(memory_format=torch.channels_last),
                              size=output_size, mode='nearest').permute(0, 2, 3, 1).contiguous()

        # x = self.conv(x)
        x = base_conv2d(x, self.conv, channel_last=True)
        return x


class Downsample2D(nn.Module):

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2

        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[-1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 0, 0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)

        assert x.shape[-1] == self.channels

        # x = self.conv(x)
        x = base_conv2d(x, self.conv, channel_last=True)
        return x


class CausalConv(nn.Module):

    def __init__(self, chan_in, chan_out, kernel_size, **kwargs):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = kernel_size if isinstance(kernel_size, tuple) else ((kernel_size,) * 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        self.dilation = kwargs.pop('dilation', 1)
        self.stride = kwargs.pop('stride', 1)
        if isinstance(self.stride, int):
            self.stride = (self.stride, 1, 1)

        time_pad = self.dilation * (time_kernel_size - 1) + max((1 - self.stride[0]), 0)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
        self.time_uncausal_padding = (width_pad, width_pad, height_pad, height_pad, 0, 0)

        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=self.stride, dilation=self.dilation, **kwargs)
        self.is_first_run = True

    def forward(self, x, is_init=True, residual=None):
        x = nn.functional.pad(x, self.time_causal_padding if is_init else self.time_uncausal_padding)

        x = self.conv(x)
        if residual is not None:
            x.add_(residual)
        return x
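A small shape check for the CausalConv defined above, assuming it is importable from this module. It illustrates that the asymmetric temporal padding (time_pad frames in front, none behind) keeps the output frame count equal to the input when the temporal stride is 1; the channel count and spatial size are illustrative.

# Hypothetical shape check for CausalConv (values are illustrative).
import torch
from fastvideo.models.stepvideo.vae.vae import CausalConv

conv = CausalConv(chan_in=8, chan_out=8, kernel_size=3)   # time_causal_padding = (1, 1, 1, 1, 2, 0)
x = torch.randn(1, 8, 5, 32, 32)                          # (N, C, T, H, W)
y = conv(x, is_init=True)
assert y.shape == (1, 8, 5, 32, 32)                       # T preserved: the 2 extra frames are padded in front only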
class ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor = factor
        assert out_channels * factor**3 % in_channels == 0
        self.repeats = out_channels * factor**3 // in_channels

    def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
        x = x.repeat_interleave(self.repeats, dim=1)
        x = x.view(x.size(0), self.out_channels, self.factor, self.factor, self.factor,
                   x.size(2), x.size(3), x.size(4))
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(x.size(0), self.out_channels, x.size(2) * self.factor,
                   x.size(4) * self.factor, x.size(6) * self.factor)
        x = x[:, :, self.factor - 1:, :, :]
        return x


class ConvPixelShuffleUpSampleLayer3D(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        factor: int,
    ):
        super().__init__()
        self.factor = factor
        out_ratio = factor**3
        self.conv = CausalConv(in_channels, out_channels * out_ratio, kernel_size=kernel_size)

    def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
        x = self.conv(x, is_init)
        x = self.pixel_shuffle_3d(x, self.factor)
        return x

    @staticmethod
    def pixel_shuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
        batch_size, channels, depth, height, width = x.size()
        new_channels = channels // (factor**3)
        new_depth = depth * factor
        new_height = height * factor
        new_width = width * factor

        x = x.view(batch_size, new_channels, factor, factor, factor, depth, height, width)
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(batch_size, new_channels, new_depth, new_height, new_width)
        x = x[:, :, factor - 1:, :, :]
        return x


class ConvPixelUnshuffleDownSampleLayer3D(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        factor: int,
    ):
        super().__init__()
        self.factor = factor
        out_ratio = factor**3
        assert out_channels % out_ratio == 0
        self.conv = CausalConv(in_channels, out_channels // out_ratio, kernel_size=kernel_size)

    def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
        x = self.conv(x, is_init)
        x = self.pixel_unshuffle_3d(x, self.factor)
        return x

    @staticmethod
    def pixel_unshuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
        pad = (0, 0, 0, 0, factor - 1, 0)  # (left, right, top, bottom, front, back)
        x = F.pad(x, pad)
        B, C, D, H, W = x.shape
        x = x.view(B, C, D // factor, factor, H // factor, factor, W // factor, factor)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
        x = x.view(B, C * factor**3, D // factor, H // factor, W // factor)
        return x


class PixelUnshuffleChannelAveragingDownSampleLayer3D(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor = factor
        assert in_channels * factor**3 % out_channels == 0
        self.group_size = in_channels * factor**3 // out_channels

    def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
        pad = (0, 0, 0, 0, self.factor - 1, 0)  # (left, right, top, bottom, front, back)
        x = F.pad(x, pad)
        B, C, D, H, W = x.shape
        x = x.view(B, C, D // self.factor, self.factor, H // self.factor, self.factor,
                   W // self.factor, self.factor)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
        x = x.view(B, C * self.factor**3, D // self.factor, H // self.factor, W // self.factor)
        x = x.view(B, self.out_channels, self.group_size, D // self.factor, H // self.factor,
                   W // self.factor)
        x = x.mean(dim=2)
        return x


def base_group_norm_with_zero_pad(x, norm_layer, act_silu=True, pad_size=2):
    out_shape = list(x.shape)
    out_shape[1] += pad_size
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
    out[:, pad_size:] = base_group_norm(x, norm_layer, act_silu=act_silu, channel_last=True)
    out[:, :pad_size] = 0
    return out
class CausalConvChannelLast(CausalConv):

    def __init__(self, chan_in, chan_out, kernel_size, **kwargs):
        super().__init__(chan_in, chan_out, kernel_size, **kwargs)

        self.time_causal_padding = (0, 0) + self.time_causal_padding
        self.time_uncausal_padding = (0, 0) + self.time_uncausal_padding

    def forward(self, x, is_init=True, residual=None):
        if self.is_first_run:
            self.is_first_run = False
            # self.conv.weight = nn.Parameter(self.conv.weight.permute(0,2,3,4,1).contiguous())

        x = nn.functional.pad(x, self.time_causal_padding if is_init else self.time_uncausal_padding)

        x = base_conv3d_channel_last(x, self.conv, residual=residual)
        return x


class CausalConvAfterNorm(CausalConv):

    def __init__(self, chan_in, chan_out, kernel_size, **kwargs):
        super().__init__(chan_in, chan_out, kernel_size, **kwargs)

        if self.time_causal_padding == (1, 1, 1, 1, 2, 0):
            self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=self.stride,
                                  dilation=self.dilation, padding=(0, 1, 1), **kwargs)
        else:
            self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=self.stride,
                                  dilation=self.dilation, **kwargs)
        self.is_first_run = True

    def forward(self, x, is_init=True, residual=None):
        if self.is_first_run:
            self.is_first_run = False

        if self.time_causal_padding == (1, 1, 1, 1, 2, 0):
            pass
        else:
            x = nn.functional.pad(x, self.time_causal_padding).contiguous()

        x = base_conv3d_channel_last(x, self.conv, residual=residual)
        return x


class AttnBlock(nn.Module):

    def __init__(self, in_channels):
        super().__init__()

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.q = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
        self.k = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
        self.v = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
        self.proj_out = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)

    def attention(self, x, is_init=True):
        x = base_group_norm(x, self.norm, act_silu=False, channel_last=True)
        q = self.q(x, is_init)
        k = self.k(x, is_init)
        v = self.v(x, is_init)

        b, t, h, w, c = q.shape
        q, k, v = map(lambda x: rearrange(x, "b t h w c -> b 1 (t h w) c"), (q, k, v))
        x = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        x = rearrange(x, "b 1 (t h w) c -> b t h w c", t=t, h=h, w=w)
        return x

    def forward(self, x):
        x = x.permute(0, 2, 3, 4, 1).contiguous()

        h = self.attention(x)
        x = self.proj_out(h, residual=x)

        x = x.permute(0, 4, 1, 2, 3)
        return x


class Resnet3DBlock(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels=None,
        temb_channels=512,
        conv_shortcut=False,
    ):
        super().__init__()

        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.conv1 = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)

        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels)
        self.conv2 = CausalConvAfterNorm(out_channels, out_channels, kernel_size=3)

        assert conv_shortcut is False
        self.use_conv_shortcut = conv_shortcut
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3)
            else:
                self.nin_shortcut = CausalConvAfterNorm(in_channels, out_channels, kernel_size=1)

    def forward(self, x, temb=None, is_init=True):
        x = x.permute(0, 2, 3, 4, 1).contiguous()

        h = base_group_norm_with_zero_pad(x, self.norm1, act_silu=True, pad_size=2)
        h = self.conv1(h)
        if temb is not None:
            h = h + self.temb_proj(nn.functional.silu(temb))[:, :, None, None]

        x = self.nin_shortcut(x) if self.in_channels != self.out_channels else x

        h = base_group_norm_with_zero_pad(h, self.norm2, act_silu=True, pad_size=2)
        x = self.conv2(h, residual=x)

        x = x.permute(0, 4, 1, 2, 3)
        return x
class Downsample3D(nn.Module):

    def __init__(self, in_channels, with_conv, stride):
        super().__init__()

        self.with_conv = with_conv
        if with_conv:
            self.conv = CausalConv(in_channels, in_channels, kernel_size=3, stride=stride)

    def forward(self, x, is_init=True):
        if self.with_conv:
            x = self.conv(x, is_init)
        else:
            x = nn.functional.avg_pool3d(x, kernel_size=2, stride=2)
        return x


class VideoEncoder(nn.Module):

    def __init__(
        self,
        ch=32,
        ch_mult=(4, 8, 16, 16),
        num_res_blocks=2,
        in_channels=3,
        z_channels=16,
        double_z=True,
        down_sampling_layer=[1, 2],
        resamp_with_conv=True,
        version=1,
    ):
        super().__init__()
        temb_ch = 0

        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks

        # downsampling
        self.conv_in = CausalConv(in_channels, ch, kernel_size=3)

        self.down_sampling_layer = down_sampling_layer

        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(Resnet3DBlock(in_channels=block_in, out_channels=block_out, temb_channels=temb_ch))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                if i_level in self.down_sampling_layer:
                    down.downsample = Downsample3D(block_in, resamp_with_conv, stride=(2, 2, 2))
                else:
                    down.downsample = Downsample2D(block_in, resamp_with_conv, padding=0)  # DIFF
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in)

        self.version = version
        if version == 2:
            channels = 4 * z_channels * 2**3
            self.conv_patchify = ConvPixelUnshuffleDownSampleLayer3D(block_in, channels, kernel_size=3, factor=2)
            self.shortcut_pathify = PixelUnshuffleChannelAveragingDownSampleLayer3D(block_in, channels, 2)
            self.shortcut_out = PixelUnshuffleChannelAveragingDownSampleLayer3D(
                channels, 2 * z_channels if double_z else z_channels, 1)
            self.conv_out = CausalConvChannelLast(channels, 2 * z_channels if double_z else z_channels,
                                                  kernel_size=3)
        else:
            self.conv_out = CausalConvAfterNorm(block_in, 2 * z_channels if double_z else z_channels,
                                                kernel_size=3)

    @torch.inference_mode()
    def forward(self, x, video_frame_num, is_init=True):
        # timestep embedding
        temb = None
        t = video_frame_num

        # downsampling
        h = self.conv_in(x, is_init)

        # make it real channel last, but behave like normal layout
        h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3)

        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb, is_init)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                if isinstance(self.down[i_level].downsample, Downsample2D):
                    _, _, t, _, _ = h.shape
                    h = rearrange(h, "b c t h w -> (b t) h w c", t=t)
                    h = self.down[i_level].downsample(h)
                    h = rearrange(h, "(b t) h w c -> b c t h w", t=t)
                else:
                    h = self.down[i_level].downsample(h, is_init)

        h = self.mid.block_1(h, temb, is_init)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb, is_init)

        h = h.permute(0, 2, 3, 4, 1).contiguous()  # b c l h w -> b l h w c
        if self.version == 2:
            h = base_group_norm(h, self.norm_out, act_silu=True, channel_last=True)
            h = h.permute(0, 4, 1, 2, 3).contiguous()
            shortcut = self.shortcut_pathify(h, is_init)
            h = self.conv_patchify(h, is_init)
            h = h.add_(shortcut)
            shortcut = self.shortcut_out(h, is_init).permute(0, 2, 3, 4, 1)
            h = self.conv_out(h.permute(0, 2, 3, 4, 1).contiguous(), is_init)
            h = h.add_(shortcut)
        else:
            h = base_group_norm_with_zero_pad(h, self.norm_out, act_silu=True, pad_size=2)
            h = self.conv_out(h, is_init)
        h = h.permute(0, 4, 1, 2, 3)  # b l h w c -> b c l h w

        h = rearrange(h, "b c t h w -> b t c h w")
        return h
class Res3DBlockUpsample(nn.Module):

    def __init__(self, input_filters, num_filters, down_sampling_stride, down_sampling=False):
        super().__init__()

        self.input_filters = input_filters
        self.num_filters = num_filters

        self.act_ = nn.SiLU(inplace=True)

        self.conv1 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3])
        self.norm1 = nn.GroupNorm(32, num_filters)

        self.conv2 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3])
        self.norm2 = nn.GroupNorm(32, num_filters)

        self.down_sampling = down_sampling
        if down_sampling:
            self.down_sampling_stride = down_sampling_stride
        else:
            self.down_sampling_stride = [1, 1, 1]

        if num_filters != input_filters or down_sampling:
            self.conv3 = CausalConvChannelLast(input_filters, num_filters, kernel_size=[1, 1, 1],
                                               stride=self.down_sampling_stride)
            self.norm3 = nn.GroupNorm(32, num_filters)

    def forward(self, x, is_init=False):
        x = x.permute(0, 2, 3, 4, 1).contiguous()

        residual = x

        h = self.conv1(x, is_init)
        h = base_group_norm(h, self.norm1, act_silu=True, channel_last=True)

        h = self.conv2(h, is_init)
        h = base_group_norm(h, self.norm2, act_silu=False, channel_last=True)

        if self.down_sampling or self.num_filters != self.input_filters:
            x = self.conv3(x, is_init)
            x = base_group_norm(x, self.norm3, act_silu=False, channel_last=True)

        h.add_(x)
        h = self.act_(h)

        if residual is not None:
            h.add_(residual)

        h = h.permute(0, 4, 1, 2, 3)
        return h


class Upsample3D(nn.Module):

    def __init__(self, in_channels, scale_factor=2):
        super().__init__()

        self.scale_factor = scale_factor
        self.conv3d = Res3DBlockUpsample(input_filters=in_channels, num_filters=in_channels,
                                         down_sampling_stride=(1, 1, 1), down_sampling=False)

    def forward(self, x, is_init=True, is_split=True):
        b, c, t, h, w = x.shape

        # x = x.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3).to(memory_format=torch.channels_last_3d)

        if is_split:
            split_size = c // 8
            x_slices = torch.split(x, split_size, dim=1)
            x = [nn.functional.interpolate(x, scale_factor=self.scale_factor) for x in x_slices]
            x = torch.cat(x, dim=1)
        else:
            x = nn.functional.interpolate(x, scale_factor=self.scale_factor)

        x = self.conv3d(x, is_init)
        return x


class VideoDecoder(nn.Module):

    def __init__(
        self,
        ch=128,
        z_channels=16,
        out_channels=3,
        ch_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        temporal_up_layers=[2, 3],
        temporal_downsample=4,
        resamp_with_conv=True,
        version=1,
    ):
        super().__init__()
        temb_ch = 0

        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.temporal_downsample = temporal_downsample

        block_in = ch * ch_mult[self.num_resolutions - 1]

        self.version = version
        if version == 2:
            channels = 4 * z_channels * 2**3
            self.conv_in = CausalConv(z_channels, channels, kernel_size=3)
            self.shortcut_in = ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(z_channels, channels, 1)
            self.conv_unpatchify = ConvPixelShuffleUpSampleLayer3D(channels, block_in, kernel_size=3, factor=2)
            self.shortcut_unpathify = ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(channels, block_in, 2)
        else:
            self.conv_in = CausalConv(z_channels, block_in, kernel_size=3)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)

        # upsampling
        self.up_id = len(temporal_up_layers)
        self.video_frame_num = 1
        self.cur_video_frame_num = self.video_frame_num // 2**self.up_id + 1
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(Resnet3DBlock(in_channels=block_in, out_channels=block_out, temb_channels=temb_ch))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                if i_level in temporal_up_layers:
                    up.upsample = Upsample3D(block_in)
                    self.cur_video_frame_num = self.cur_video_frame_num * 2
                else:
                    up.upsample = Upsample2D(block_in, resamp_with_conv)
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in)
        self.conv_out = CausalConvAfterNorm(block_in, out_channels, kernel_size=3)

    @torch.inference_mode()
    def forward(self, z, is_init=True):
        z = rearrange(z, "b t c h w -> b c t h w")

        h = self.conv_in(z, is_init=is_init)
        if self.version == 2:
            shortcut = self.shortcut_in(z, is_init=is_init)
            h = h.add_(shortcut)
            shortcut = self.shortcut_unpathify(h, is_init=is_init)
            h = self.conv_unpatchify(h, is_init=is_init)
            h = h.add_(shortcut)

        temb = None

        h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3)
        h = self.mid.block_1(h, temb, is_init=is_init)
        h = self.mid.attn_1(h)
        h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3)
        h = self.mid.block_2(h, temb, is_init=is_init)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3)
                h = self.up[i_level].block[i_block](h, temb, is_init=is_init)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                if isinstance(self.up[i_level].upsample, Upsample2D):
                    B = h.size(0)
                    h = h.permute(0, 2, 3, 4, 1).flatten(0, 1)
                    h = self.up[i_level].upsample(h)
                    h = h.unflatten(0, (B, -1)).permute(0, 4, 1, 2, 3)
                else:
                    h = self.up[i_level].upsample(h, is_init=is_init)

        # end
        h = h.permute(0, 2, 3, 4, 1)  # b c l h w -> b l h w c
        h = base_group_norm_with_zero_pad(h, self.norm_out, act_silu=True, pad_size=2)
        h = self.conv_out(h)
        h = h.permute(0, 4, 1, 2, 3)

        if is_init:
            h = h[:, :, (self.temporal_downsample - 1):]
        return h
def rms_norm(input, normalized_shape, eps=1e-6):
    dtype = input.dtype
    input = input.to(torch.float32)
    variance = input.pow(2).flatten(-len(normalized_shape)).mean(-1)[(..., ) + (None, ) * len(normalized_shape)]
    input = input * torch.rsqrt(variance + eps)
    return input.to(dtype)


class DiagonalGaussianDistribution(object):

    def __init__(self, parameters, deterministic=False, rms_norm_mean=False, only_return_mean=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=-3)  # N,[X],C,H,W
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        self.deterministic = deterministic
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean,
                                                   device=self.parameters.device,
                                                   dtype=self.parameters.dtype)
        if rms_norm_mean:
            self.mean = rms_norm(self.mean, self.mean.size()[1:])
        self.only_return_mean = only_return_mean

    def sample(self, generator=None):
        # make sure sample is on the same device
        # as the parameters and has same dtype
        sample = torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
        sample = sample.to(dtype=self.parameters.dtype)
        x = self.mean + self.std * sample
        if self.only_return_mean:
            return self.mean
        else:
            return x


class AutoencoderKL(nn.Module):

    @with_empty_init
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        z_channels=16,
        num_res_blocks=2,
        model_path=None,
        weight_dict={},
        world_size=1,
        version=1,
    ):
        super().__init__()
        self.frame_len = 17
        self.latent_len = 3 if version == 2 else 5

        base_group_norm.spatial = True if version == 2 else False

        self.encoder = VideoEncoder(
            in_channels=in_channels,
            z_channels=z_channels,
            num_res_blocks=num_res_blocks,
            version=version,
        )
        self.decoder = VideoDecoder(
            z_channels=z_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            version=version,
        )
        if model_path is not None:
            weight_dict = self.init_from_ckpt(model_path)
        if len(weight_dict) != 0:
            self.load_from_dict(weight_dict)

        self.convert_channel_last()
        self.world_size = world_size

    def init_from_ckpt(self, model_path):
        from safetensors import safe_open
        p = {}
        with safe_open(model_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                tensor = f.get_tensor(k)
                if k.startswith("decoder.conv_out."):
                    k = k.replace("decoder.conv_out.", "decoder.conv_out.conv.")
                p[k] = tensor
        return p

    def load_from_dict(self, p):
        self.load_state_dict(p)

    def convert_channel_last(self):
        # Conv2d NCHW->NHWC
        pass

    def naive_encode(self, x, is_init_image=True):
        b, len, c, h, w = x.size()
        x = rearrange(x, 'b l c h w -> b c l h w').contiguous()
        z = self.encoder(x, len, True)  # downsampling factors [1, 4, 8, 16, 16]
        return z

    @torch.inference_mode()
    def encode(self, x):
        # b (nc cf) c h w -> (b nc) cf c h w -> encode -> (b nc) cf c h w -> b (nc cf) c h w
        chunks = list(x.split(self.frame_len, dim=1))
        for i in range(len(chunks)):
            chunks[i] = self.naive_encode(chunks[i], True)
        z = torch.cat(chunks, dim=1)

        posterior = DiagonalGaussianDistribution(z)
        return posterior.sample()

    def decode_naive(self, z, is_init=True):
        z = z.to(next(self.decoder.parameters()).dtype)
        dec = self.decoder(z, is_init)
        return dec

    @torch.inference_mode()
    def decode(self, z):
        # b (nc cf) c h w -> (b nc) cf c h w -> decode -> (b nc) c cf h w -> b (nc cf) c h w
        chunks = list(z.split(self.latent_len, dim=1))
        if self.world_size > 1:
            chunks_total_num = len(chunks)
            max_num_per_rank = (chunks_total_num + self.world_size - 1) // self.world_size
            rank = torch.distributed.get_rank()
            chunks_ = chunks[max_num_per_rank * rank:max_num_per_rank * (rank + 1)]
            if len(chunks_) < max_num_per_rank:
                chunks_.extend(chunks[:max_num_per_rank - len(chunks_)])
            chunks = chunks_
        for i in range(len(chunks)):
            chunks[i] = self.decode_naive(chunks[i], True).permute(0, 2, 1, 3, 4)
        x = torch.cat(chunks, dim=1)

        if self.world_size > 1:
            x_ = torch.empty([x.size(0), (self.world_size * max_num_per_rank) * self.frame_len, *x.shape[2:]],
                             dtype=x.dtype, device=x.device)
            torch.distributed.all_gather_into_tensor(x_, x)
            x = x_[:, :chunks_total_num * self.frame_len]

        x = self.mix(x)
        return x

    def mix(self, x):
        remain_scale = 0.6
        mix_scale = 1. - remain_scale
        front = slice(self.frame_len - 1, x.size(1) - 1, self.frame_len)
        back = slice(self.frame_len, x.size(1), self.frame_len)
        x[:, back] = x[:, back] * remain_scale + x[:, front] * mix_scale
        x[:, front] = x[:, front] * remain_scale + x[:, back] * mix_scale
        return x
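A hedged round-trip sketch for the AutoencoderKL above. The checkpoint path, spatial size, dtype and device are illustrative assumptions; the only constraint taken from the code is the (B, T, C, H, W) layout with T a multiple of frame_len = 17.

# Hypothetical encode/decode round trip (paths and sizes are illustrative, not from the repo).
import torch
from fastvideo.models.stepvideo.vae.vae import AutoencoderKL

vae = AutoencoderKL(z_channels=16, model_path="vae/vae.safetensors", version=1).to(torch.bfloat16).cuda().eval()
video = torch.rand(1, 17, 3, 256, 256, dtype=torch.bfloat16, device="cuda") * 2 - 1  # (B, T, C, H, W) in [-1, 1]
latents = vae.encode(video)   # one 17-frame chunk -> latent_len (5) latent frames, spatially downsampled
recon = vae.decode(latents)   # back to roughly the original (B, 17, 3, 256, 256) pixel layout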
FastVideo-main/fastvideo/sample/call_remote_server_stepvideo.py  (new file, mode 100644)
import argparse
import os
import pickle
import threading

import torch
from flask import Blueprint, Flask, Response, request
from flask_restful import Api, Resource

device = f'cuda:{torch.cuda.device_count() - 1}'
dtype = torch.bfloat16


def parsed_args():
    parser = argparse.ArgumentParser(description="StepVideo API Functions")
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--clip_dir', type=str, default='hunyuan_clip')
    parser.add_argument('--llm_dir', type=str, default='step_llm')
    parser.add_argument('--vae_dir', type=str, default='vae')
    parser.add_argument('--port', type=str, default='8080')
    args = parser.parse_args()
    return args


class StepVaePipeline(Resource):

    def __init__(self, vae_dir, version=2):
        self.vae = self.build_vae(vae_dir, version)
        self.scale_factor = 1.0

    def build_vae(self, vae_dir, version=2):
        from fastvideo.models.stepvideo.vae.vae import AutoencoderKL
        (model_name, z_channels) = ("vae_v2.safetensors", 64) if version == 2 else ("vae.safetensors", 16)
        model_path = os.path.join(vae_dir, model_name)
        model = AutoencoderKL(
            z_channels=z_channels,
            model_path=model_path,
            version=version,
        ).to(dtype).to(device).eval()
        print("Initialized vae...")
        return model

    def decode(self, samples, *args, **kwargs):
        with torch.no_grad():
            try:
                dtype = next(self.vae.parameters()).dtype
                device = next(self.vae.parameters()).device
                samples = self.vae.decode(samples.to(dtype).to(device) / self.scale_factor)
                if hasattr(samples, 'sample'):
                    samples = samples.sample
                return samples
            except:
                torch.cuda.empty_cache()
                return None


lock = threading.Lock()


class VAEapi(Resource):

    def __init__(self, vae_pipeline):
        self.vae_pipeline = vae_pipeline

    def get(self):
        with lock:
            try:
                feature = pickle.loads(request.get_data())
                feature['api'] = 'vae'
                feature = {k: v for k, v in feature.items() if v is not None}
                video_latents = self.vae_pipeline.decode(**feature)
                response = pickle.dumps(video_latents)
            except Exception as e:
                print("Caught Exception: ", e)
                return Response(e)
            return Response(response)


class CaptionPipeline(Resource):

    def __init__(self, llm_dir, clip_dir):
        self.text_encoder = self.build_llm(llm_dir)
        self.clip = self.build_clip(clip_dir)

    def build_llm(self, model_dir):
        from fastvideo.models.stepvideo.text_encoder.stepllm import STEP1TextEncoder
        text_encoder = STEP1TextEncoder(model_dir, max_length=320).to(dtype).to(device).eval()
        print("Initialized text encoder...")
        return text_encoder

    def build_clip(self, model_dir):
        from fastvideo.models.stepvideo.text_encoder.clip import HunyuanClip
        clip = HunyuanClip(model_dir, max_length=77).to(device).eval()
        print("Initialized clip encoder...")
        return clip

    def embedding(self, prompts, *args, **kwargs):
        with torch.no_grad():
            try:
                y, y_mask = self.text_encoder(prompts)
                clip_embedding, _ = self.clip(prompts)

                len_clip = clip_embedding.shape[1]
                y_mask = torch.nn.functional.pad(y_mask, (len_clip, 0), value=1)  ## pad attention_mask with clip's length

                data = {
                    'y': y.detach().cpu(),
                    'y_mask': y_mask.detach().cpu(),
                    'clip_embedding': clip_embedding.to(torch.bfloat16).detach().cpu()
                }
                return data
            except Exception as err:
                print(f"{err}")
                return None


lock = threading.Lock()


class Captionapi(Resource):

    def __init__(self, caption_pipeline):
        self.caption_pipeline = caption_pipeline

    def get(self):
        with lock:
            try:
                feature = pickle.loads(request.get_data())
                feature['api'] = 'caption'
                feature = {k: v for k, v in feature.items() if v is not None}
                embeddings = self.caption_pipeline.embedding(**feature)
                response = pickle.dumps(embeddings)
            except Exception as e:
                print("Caught Exception: ", e)
                return Response(e)
            return Response(response)


class RemoteServer(object):

    def __init__(self, args) -> None:
        self.app = Flask(__name__)
        root = Blueprint("root", __name__)
        self.app.register_blueprint(root)
        api = Api(self.app)

        self.vae_pipeline = StepVaePipeline(vae_dir=os.path.join(args.model_dir, args.vae_dir))
        api.add_resource(
            VAEapi,
            "/vae-api",
            resource_class_args=[self.vae_pipeline],
        )

        self.caption_pipeline = CaptionPipeline(llm_dir=os.path.join(args.model_dir, args.llm_dir),
                                                clip_dir=os.path.join(args.model_dir, args.clip_dir))
        api.add_resource(
            Captionapi,
            "/caption-api",
            resource_class_args=[self.caption_pipeline],
        )

    def run(self, host="0.0.0.0", port=8080):
        self.app.run(host, port=port, threaded=True, debug=False)


if __name__ == "__main__":
    args = parsed_args()
    flask_server = RemoteServer(args)
    flask_server.run(host="0.0.0.0", port=args.port)
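The server above exchanges pickled payloads over plain HTTP GET requests. Below is a hedged client-side sketch, assuming the `requests` package and a server already running on localhost:8080; the prompt text is made up, and the payload key mirrors the `prompts` argument of CaptionPipeline.embedding above.

# Hypothetical client for the caption endpoint (server assumed to be running locally).
import pickle

import requests

payload = pickle.dumps({"prompts": "a corgi surfing a wave at sunset"})
resp = requests.get("http://127.0.0.1:8080/caption-api", data=payload)
embeddings = pickle.loads(resp.content)  # dict with 'y', 'y_mask', 'clip_embedding' CPU tensors
print(embeddings["y"].shape, embeddings["clip_embedding"].shape)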
FastVideo-main/fastvideo/sample/generate_synthetic.py  (new file, mode 100644)
import argparse
import json
import os

import torch
import torch.distributed as dist
from diffusers.utils import export_to_video

from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline


def generate_video_and_latent(pipe, prompt, height, width, num_frames, num_inference_steps, guidance_scale):
    # Set the random seed for reproducibility
    generator = torch.Generator("cuda").manual_seed(12345)
    # Generate videos from the input prompt
    noise, video, latent, prompt_embed, prompt_attention_mask = pipe(
        prompt=prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        generator=generator,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        output_type="latent_and_video",
    )
    # prompt_embed has negative prompt at index 0
    return noise[0], video[0], latent[0], prompt_embed[1], prompt_attention_mask[1]
    # return dummy tensor to debug first
    # return torch.zeros(1, 3, 480, 848), torch.zeros(1, 256, 16, 16)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument("--height", type=int, default=480)
    parser.add_argument("--width", type=int, default=848)
    parser.add_argument("--num_inference_steps", type=int, default=64)
    parser.add_argument("--guidance_scale", type=float, default=4.5)
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--prompt_path", type=str, default="data/dummyVid/videos2caption.json")
    parser.add_argument("--dataset_output_dir", type=str, default="data/dummySynthetic")

    args = parser.parse_args()
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size, "local rank", local_rank)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)

    if not isinstance(args.prompt_path, list):
        args.prompt_path = [args.prompt_path]
    if len(args.prompt_path) == 1 and args.prompt_path[0].endswith("txt"):
        text_prompt = open(args.prompt_path[0], "r").readlines()
        text_prompt = [i.strip() for i in text_prompt]

    pipe = MochiPipeline.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    pipe.enable_vae_tiling()
    pipe.enable_model_cpu_offload(gpu_id=local_rank)

    # make dir if not exist
    os.makedirs(args.dataset_output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.dataset_output_dir, "noise"), exist_ok=True)
    os.makedirs(os.path.join(args.dataset_output_dir, "video"), exist_ok=True)
    os.makedirs(os.path.join(args.dataset_output_dir, "latent"), exist_ok=True)
    os.makedirs(os.path.join(args.dataset_output_dir, "prompt_embed"), exist_ok=True)
    os.makedirs(os.path.join(args.dataset_output_dir, "prompt_attention_mask"), exist_ok=True)
    data = []
    for i, prompt in enumerate(text_prompt):
        if i % world_size != local_rank:
            continue
        (
            noise,
            video,
            latent,
            prompt_embed,
            prompt_attention_mask,
        ) = generate_video_and_latent(
            pipe,
            prompt,
            args.height,
            args.width,
            args.num_frames,
            args.num_inference_steps,
            args.guidance_scale,
        )

        # save latent
        video_name = str(i)
        noise_path = os.path.join(args.dataset_output_dir, "noise", video_name + ".pt")
        latent_path = os.path.join(args.dataset_output_dir, "latent", video_name + ".pt")
        prompt_embed_path = os.path.join(args.dataset_output_dir, "prompt_embed", video_name + ".pt")
        video_path = os.path.join(args.dataset_output_dir, "video", video_name + ".mp4")
        prompt_attention_mask_path = os.path.join(args.dataset_output_dir, "prompt_attention_mask",
                                                  video_name + ".pt")
        # save latent
        torch.save(noise, noise_path)
        torch.save(latent, latent_path)
        torch.save(prompt_embed, prompt_embed_path)
        torch.save(prompt_attention_mask, prompt_attention_mask_path)
        export_to_video(video, video_path, fps=30)
        item = {}
        item["cap"] = prompt
        item["video"] = video_name + ".mp4"
        item["noise"] = video_name + ".pt"
        item["latent_path"] = video_name + ".pt"
        item["prompt_embed_path"] = video_name + ".pt"
        item["prompt_attention_mask"] = video_name + ".pt"
        data.append(item)

    dist.barrier()
    local_data = data
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, local_data)
    # save json
    if local_rank == 0:
        all_data = [item for sublist in gathered_data for item in sublist]
        with open(os.path.join(args.dataset_output_dir, "videos2caption.json"), "w") as f:
            json.dump(all_data, f, indent=4)
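A small reader sketch for the synthetic dataset this script writes out: one entry of videos2caption.json plus the matching tensors under latent/, prompt_embed/ and prompt_attention_mask/. The directory name is the script's own default; everything else is illustrative.

# Hypothetical loader for one record of the generated synthetic dataset.
import json
import os

import torch

root = "data/dummySynthetic"
with open(os.path.join(root, "videos2caption.json")) as f:
    records = json.load(f)

item = records[0]
latent = torch.load(os.path.join(root, "latent", item["latent_path"]))
prompt_embed = torch.load(os.path.join(root, "prompt_embed", item["prompt_embed_path"]))
prompt_mask = torch.load(os.path.join(root, "prompt_attention_mask", item["prompt_attention_mask"]))
print(item["cap"], latent.shape, prompt_embed.shape, prompt_mask.shape)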
FastVideo-main/fastvideo/sample/sample_t2v_hunyuan.py  (new file, mode 100644)
import argparse
import os
from pathlib import Path

import imageio
import numpy as np
import torch
import torch.distributed as dist
import torchvision
from einops import rearrange

from fastvideo.models.hunyuan.inference import HunyuanVideoSampler
from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, nccl_info


def initialize_distributed():
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)
    initialize_sequence_parallel_state(world_size)


def main(args):
    initialize_distributed()
    print(nccl_info.sp_size)
    print(args)
    models_root_path = Path(args.model_path)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` not exists: {models_root_path}")

    # Create save folder to save the samples
    save_path = args.output_path
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Load models
    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)

    # Get the updated args
    args = hunyuan_video_sampler.args

    if args.prompt.endswith('.txt'):
        with open(args.prompt) as f:
            prompts = [line.strip() for line in f.readlines()]
    else:
        prompts = [args.prompt]

    for prompt in prompts:
        outputs = hunyuan_video_sampler.predict(
            prompt=prompt,
            height=args.height,
            width=args.width,
            video_length=args.num_frames,
            seed=args.seed,
            negative_prompt=args.neg_prompt,
            infer_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            num_videos_per_prompt=args.num_videos,
            flow_shift=args.flow_shift,
            batch_size=args.batch_size,
            embedded_guidance_scale=args.embedded_cfg_scale,
        )
        videos = rearrange(outputs["samples"], "b c t h w -> t b c h w")
        outputs = []
        for x in videos:
            x = torchvision.utils.make_grid(x, nrow=6)
            x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
            outputs.append((x * 255).numpy().astype(np.uint8))
        os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
        imageio.mimsave(os.path.join(args.output_path, f"{prompt[:100]}.mp4"), outputs, fps=args.fps)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Basic parameters
    parser.add_argument("--prompt", type=str, help="prompt file for inference")
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--height", type=int, default=256)
    parser.add_argument("--width", type=int, default=256)
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--model_path", type=str, default="data/hunyuan")
    parser.add_argument("--output_path", type=str, default="./outputs/video")
    parser.add_argument("--fps", type=int, default=24)

    # Additional parameters
    parser.add_argument(
        "--denoise-type",
        type=str,
        default="flow",
        help="Denoise type for noised inputs.",
    )
    parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
    parser.add_argument("--neg_prompt", type=str, default=None, help="Negative prompt for sampling.")
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=1.0,
        help="Classifier free guidance scale.",
    )
    parser.add_argument(
        "--embedded_cfg_scale",
        type=float,
        default=6.0,
        help="Embedded classifier free guidance scale.",
    )
    parser.add_argument("--flow_shift", type=int, default=7, help="Flow shift parameter.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference.")
    parser.add_argument(
        "--num_videos",
        type=int,
        default=1,
        help="Number of videos to generate per prompt.",
    )
    parser.add_argument(
        "--load-key",
        type=str,
        default="module",
        help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
    )
    parser.add_argument(
        "--use-cpu-offload",
        action="store_true",
        help="Use CPU offload for the model load.",
    )
    parser.add_argument(
        "--dit-weight",
        type=str,
        default="data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
    )
    parser.add_argument(
        "--reproduce",
        action="store_true",
        help="Enable reproducibility by setting random seeds and deterministic algorithms.",
    )
    parser.add_argument(
        "--disable-autocast",
        action="store_true",
        help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
    )

    # Flow Matching
    parser.add_argument(
        "--flow-reverse",
        action="store_true",
        help="If reverse, learning/sampling from t=1 -> t=0.",
    )
    parser.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.")
    parser.add_argument(
        "--use-linear-quadratic-schedule",
        action="store_true",
        help="Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
    )
    parser.add_argument(
        "--linear-schedule-end",
        type=int,
        default=25,
        help="End step for linear quadratic schedule for flow matching.",
    )

    # Model parameters
    parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill")
    parser.add_argument("--latent-channels", type=int, default=16)
    parser.add_argument("--precision", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")

    parser.add_argument("--vae", type=str, default="884-16c-hy")
    parser.add_argument("--vae-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--vae-tiling", action="store_true", default=True)
    parser.add_argument("--vae-sp", action="store_true", default=False)

    parser.add_argument("--text-encoder", type=str, default="llm")
    parser.add_argument(
        "--text-encoder-precision",
        type=str,
        default="fp16",
        choices=["fp32", "fp16", "bf16"],
    )
    parser.add_argument("--text-states-dim", type=int, default=4096)
    parser.add_argument("--text-len", type=int, default=256)
    parser.add_argument("--tokenizer", type=str, default="llm")
    parser.add_argument("--prompt-template", type=str, default="dit-llm-encode")
    parser.add_argument("--prompt-template-video", type=str, default="dit-llm-encode-video")
    parser.add_argument("--hidden-state-skip-layer", type=int, default=2)
    parser.add_argument("--apply-final-norm", action="store_true")

    parser.add_argument("--text-encoder-2", type=str, default="clipL")
    parser.add_argument(
        "--text-encoder-precision-2",
        type=str,
        default="fp16",
        choices=["fp32", "fp16", "bf16"],
    )
    parser.add_argument(
        "--enable_torch_compile",
        action="store_true",
        help="Use torch.compile for speeding up STA inference without teacache",
    )
    parser.add_argument("--text-states-dim-2", type=int, default=768)
    parser.add_argument("--tokenizer-2", type=str, default="clipL")
    parser.add_argument("--text-len-2", type=int, default=77)

    args = parser.parse_args()
    # process for vae sequence parallel
    if args.vae_sp and not args.vae_tiling:
        raise ValueError(
            "Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True."
        )

    main(args)
FastVideo-main/fastvideo/sample/sample_t2v_hunyuan_STA.py  (new file, mode 100644)
import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

import imageio
import numpy as np
import torch
import torch.distributed as dist
import torchvision
from einops import rearrange

from fastvideo.models.hunyuan.inference import HunyuanVideoSampler
from fastvideo.models.hunyuan.modules.modulate_layers import modulate
from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, nccl_info


def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_attention_mask: torch.Tensor,
    mask_strategy=None,
    output_features=False,
    output_features_stride=8,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = False,
    guidance=None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    if guidance is None:
        guidance = torch.tensor([6016.0], device=hidden_states.device, dtype=torch.bfloat16)

    img = x = hidden_states
    text_mask = encoder_attention_mask
    t = timestep
    txt = encoder_hidden_states[:, 1:]
    text_states_2 = encoder_hidden_states[:, 0, :self.config.text_states_dim_2]
    _, _, ot, oh, ow = x.shape  # codespell:ignore
    tt, th, tw = (
        ot // self.patch_size[0],  # codespell:ignore
        oh // self.patch_size[1],  # codespell:ignore
        ow // self.patch_size[2],  # codespell:ignore
    )
    original_tt = nccl_info.sp_size * tt
    freqs_cos, freqs_sin = self.get_rotary_pos_embed((original_tt, th, tw))

    # Prepare modulation vectors.
    vec = self.time_in(t)

    # text modulation
    vec = vec + self.vector_in(text_states_2)

    # guidance modulation
    if self.guidance_embed:
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")
        # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
        vec = vec + self.guidance_in(guidance)

    # Embed image and text.
    img = self.img_in(img)
    if self.text_projection == "linear":
        txt = self.txt_in(txt)
    elif self.text_projection == "single_refiner":
        txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
    else:
        raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")

    txt_seq_len = txt.shape[1]
    img_seq_len = img.shape[1]

    freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None

    if self.enable_teacache:
        inp = img.clone()
        vec_ = vec.clone()
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
        normed_inp = self.double_blocks[0].img_norm1(inp)
        modulated_inp = modulate(normed_inp, shift=img_mod1_shift, scale=img_mod1_scale)
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.previous_modulated_input).abs().mean() /
                 self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.cnt = 0

    if self.enable_teacache:
        if not should_calc:
            img += self.previous_residual
        else:
            ori_img = img.clone()
            # --------------------- Pass through DiT blocks ------------------------
            for index, block in enumerate(self.double_blocks):
                double_block_args = [img, txt, vec, freqs_cis, text_mask, mask_strategy[index]]
                img, txt = block(*double_block_args)

            # Merge txt and img to pass through single stream blocks.
            x = torch.cat((img, txt), 1)
            if output_features:
                features_list = []
            if len(self.single_blocks) > 0:
                for index, block in enumerate(self.single_blocks):
                    single_block_args = [
                        x,
                        vec,
                        txt_seq_len,
                        (freqs_cos, freqs_sin),
                        text_mask,
                        mask_strategy[index + len(self.double_blocks)],
                    ]
                    x = block(*single_block_args)
                    if output_features and _ % output_features_stride == 0:
                        features_list.append(x[:, :img_seq_len, ...])

            img = x[:, :img_seq_len, ...]
            self.previous_residual = img - ori_img
    else:
        # --------------------- Pass through DiT blocks ------------------------
        for index, block in enumerate(self.double_blocks):
            double_block_args = [img, txt, vec, freqs_cis, text_mask, mask_strategy[index]]
            img, txt = block(*double_block_args)

        # Merge txt and img to pass through single stream blocks.
        x = torch.cat((img, txt), 1)
        if output_features:
            features_list = []
        if len(self.single_blocks) > 0:
            for index, block in enumerate(self.single_blocks):
                single_block_args = [
                    x,
                    vec,
                    txt_seq_len,
                    (freqs_cos, freqs_sin),
                    text_mask,
                    mask_strategy[index + len(self.double_blocks)],
                ]
                x = block(*single_block_args)
                if output_features and _ % output_features_stride == 0:
                    features_list.append(x[:, :img_seq_len, ...])

        img = x[:, :img_seq_len, ...]

    # ---------------------------- Final layer ------------------------------
    img = self.final_layer(img, vec)
    # (N, T, patch_size ** 2 * out_channels)
    img = self.unpatchify(img, tt, th, tw)
    assert not return_dict, "return_dict is not supported."
    if output_features:
        features_list = torch.stack(features_list, dim=0)
    else:
        features_list = None
    return (img, features_list)


def initialize_distributed():
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)
    initialize_sequence_parallel_state(world_size)


def main(args):
    initialize_distributed()
    print(nccl_info.sp_size)
    print(args)
    models_root_path = Path(args.model_path)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` not exists: {models_root_path}")

    # Create save folder to save the samples
    save_path = args.output_path
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Load models
    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)

    # Get the updated args
    args = hunyuan_video_sampler.args

    # teacache
    hunyuan_video_sampler.pipeline.transformer.__class__.enable_teacache = args.enable_teacache
    hunyuan_video_sampler.pipeline.transformer.__class__.cnt = 0
    hunyuan_video_sampler.pipeline.transformer.__class__.num_steps = args.num_inference_steps
    hunyuan_video_sampler.pipeline.transformer.__class__.rel_l1_thresh = args.rel_l1_thresh  # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
    hunyuan_video_sampler.pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
    hunyuan_video_sampler.pipeline.transformer.__class__.previous_modulated_input = None
    hunyuan_video_sampler.pipeline.transformer.__class__.previous_residual = None
    hunyuan_video_sampler.pipeline.transformer.__class__.forward = teacache_forward

    with open(args.mask_strategy_file_path, 'r') as f:
        mask_strategy = json.load(f)

    if args.prompt.endswith('.txt'):
        with open(args.prompt) as f:
            prompts = [line.strip() for line in f.readlines()]
    else:
        prompts = [args.prompt]

    for prompt in prompts:
        outputs = hunyuan_video_sampler.predict(
            prompt=prompt,
            height=args.height,
            width=args.width,
            video_length=args.num_frames,
            seed=args.seed,
            negative_prompt=args.neg_prompt,
            infer_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            num_videos_per_prompt=args.num_videos,
            flow_shift=args.flow_shift,
            batch_size=args.batch_size,
            embedded_guidance_scale=args.embedded_cfg_scale,
            mask_strategy=mask_strategy,
        )
        videos = rearrange(outputs["samples"], "b c t h w -> t b c h w")
        outputs = []
        for x in videos:
            x = torchvision.utils.make_grid(x, nrow=6)
            x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
            outputs.append((x * 255).numpy().astype(np.uint8))
        os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
        imageio.mimsave(os.path.join(args.output_path, f"{prompt[:100]}.mp4"), outputs, fps=args.fps)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Basic parameters
    parser.add_argument("--prompt", type=str, help="prompt file for inference")
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--height", type=int, default=256)
    parser.add_argument("--width", type=int, default=256)
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--model_path", type=str, default="data/hunyuan")
    parser.add_argument("--output_path", type=str, default="./outputs/video")
    parser.add_argument("--fps", type=int, default=24)

    # Additional parameters
    parser.add_argument(
        "--sliding_block_size",
        type=str,
        default="8,6,10",
        help="Sliding block size for sliding block attention.",
    )
    parser.add_argument("--denoise-type", type=str, default="flow", help="Denoise type for noised inputs.")
    parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
    parser.add_argument("--neg_prompt", type=str, default=None, help="Negative prompt for sampling.")
    parser.add_argument("--guidance_scale", type=float, default=1.0, help="Classifier free guidance scale.")
    parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embedded classifier free guidance scale.")
    parser.add_argument("--flow_shift", type=int, default=7, help="Flow shift parameter.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference.")
    parser.add_argument("--num_videos", type=int, default=1, help="Number of videos to generate per prompt.")
    parser.add_argument(
        "--load-key",
        type=str,
        default="module",
        help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
    )
    parser.add_argument("--use-cpu-offload", action="store_true", help="Use CPU offload for the model load.")
    parser.add_argument(
        "--dit-weight",
        type=str,
        default="data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
    )
    parser.add_argument(
        "--reproduce",
        action="store_true",
        help="Enable reproducibility by setting random seeds and deterministic algorithms.",
    )
    parser.add_argument(
        "--disable-autocast",
        action="store_true",
        help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
    )

    # Flow Matching
    parser.add_argument(
        "--flow-reverse",
        action="store_true",
        help="If reverse, learning/sampling from t=1 -> t=0.",
    )
    parser.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.")
    parser.add_argument(
        "--use-linear-quadratic-schedule",
        action="store_true",
        help="Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
    )
    parser.add_argument(
        "--linear-schedule-end",
        type=int,
        default=25,
        help="End step for linear quadratic schedule for flow matching.",
    )

    # Model parameters
    parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill")
    parser.add_argument("--latent-channels", type=int, default=16)
    parser.add_argument("--precision", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")
    parser.add_argument("--vae", type=str, default="884-16c-hy")
    parser.add_argument("--vae-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--vae-tiling", action="store_true", default=True)
    parser.add_argument("--vae-sp", action="store_true", default=False)
    parser.add_argument("--text-encoder", type=str, default="llm")
    parser.add_argument("--text-encoder-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--text-states-dim", type=int, default=4096)
    parser.add_argument("--text-len", type=int, default=256)
    parser.add_argument("--tokenizer", type=str, default="llm")
    parser.add_argument("--prompt-template", type=str, default="dit-llm-encode")
    parser.add_argument("--prompt-template-video", type=str, default="dit-llm-encode-video")
    parser.add_argument("--hidden-state-skip-layer", type=int, default=2)
    parser.add_argument("--apply-final-norm", action="store_true")
    parser.add_argument("--text-encoder-2", type=str, default="clipL")
    parser.add_argument("--text-encoder-precision-2", type=str, default="fp16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--text-states-dim-2", type=int, default=768)
    parser.add_argument("--tokenizer-2", type=str, default="clipL")
    parser.add_argument("--text-len-2", type=int, default=77)
    parser.add_argument("--skip_time_steps", type=int, default=10)
    parser.add_argument(
        "--mask_strategy_selected",
        type=lambda x: [int(i) for i in x.strip('[]').split(',')],  # Convert string to list of integers
        default=[1, 2, 6],  # Now can be directly set as a list
        help="order of candidates",
    )
    parser.add_argument(
        "--rel_l1_thresh",
        type=float,
        default=0.15,
        help="0.1 for 1.6x speedup, 0.15 for 2.1x speedup",
    )
    parser.add_argument("--enable_teacache", action="store_true", help="Use teacache for speeding up inference")
    parser.add_argument(
        "--enable_torch_compile",
        action="store_true",
        help="Use torch.compile for speeding up STA inference without teacache",
    )
    parser.add_argument("--mask_strategy_file_path", type=str, default="assets/mask_strategy.json")
    args = parser.parse_args()
    # process for vae sequence parallel
    if args.vae_sp and not args.vae_tiling:
        raise ValueError(
            "Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True."
        )
    if args.enable_teacache and args.enable_torch_compile:
        raise ValueError(
            "--enable_teacache and --enable_torch_compile cannot be used simultaneously. Please enable only one of these options."
        )
    main(args)
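Editor's note: the caching decision buried inside `teacache_forward` above can be read in isolation. The sketch below restates it with a hypothetical `TeaCacheDecider` helper (the class name and interface are illustrative, not FastVideo API): the relative L1 change of the first double block's modulated input is rescaled by a fitted polynomial and accumulated across steps; a full DiT forward pass is only recomputed when the accumulated value reaches `rel_l1_thresh`, otherwise the cached residual is reused. The first and last steps always recompute.

```python
import numpy as np

class TeaCacheDecider:
    """Illustrative re-statement of the skip/recompute logic in teacache_forward."""

    def __init__(self, num_steps, rel_l1_thresh, coefficients):
        self.num_steps = num_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.rescale = np.poly1d(coefficients)  # maps raw relative L1 to an estimated output change
        self.cnt = 0
        self.accumulated = 0.0
        self.prev = None

    def should_calc(self, modulated_inp):
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            calc, self.accumulated = True, 0.0   # always recompute at the first and last step
        else:
            rel_l1 = ((modulated_inp - self.prev).abs().mean() / self.prev.abs().mean()).item()
            self.accumulated += self.rescale(rel_l1)
            calc = self.accumulated >= self.rel_l1_thresh
            if calc:
                self.accumulated = 0.0
        self.prev = modulated_inp
        self.cnt = (self.cnt + 1) % self.num_steps
        return calc
```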
FastVideo-main/fastvideo/sample/sample_t2v_hunyuan_hf.py
0 → 100644
View file @
c07946d8
import argparse
import json
import os
import time

import torch
import torch.distributed as dist
from diffusers import BitsAndBytesConfig
from diffusers.utils import export_to_video

from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel
from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline
from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, nccl_info


def initialize_distributed():
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)
    initialize_sequence_parallel_state(world_size)


def inference(args):
    initialize_distributed()
    print(nccl_info.sp_size)
    device = torch.cuda.current_device()
    # Peiyuan: GPU seed will cause A100 and H100 to produce different results .....
    weight_dtype = torch.bfloat16
    if args.transformer_path is not None:
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(args.transformer_path)
    else:
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            args.model_path, subfolder="transformer/", torch_dtype=weight_dtype)

    pipe = HunyuanVideoPipeline.from_pretrained(args.model_path, transformer=transformer, torch_dtype=weight_dtype)
    pipe.enable_vae_tiling()

    if args.lora_checkpoint_dir is not None:
        print(f"Loading LoRA weights from {args.lora_checkpoint_dir}")
        config_path = os.path.join(args.lora_checkpoint_dir, "lora_config.json")
        with open(config_path, "r") as f:
            lora_config_dict = json.load(f)
        rank = lora_config_dict["lora_params"]["lora_rank"]
        lora_alpha = lora_config_dict["lora_params"]["lora_alpha"]
        lora_scaling = lora_alpha / rank
        pipe.load_lora_weights(args.lora_checkpoint_dir, adapter_name="default")
        pipe.set_adapters(["default"], [lora_scaling])
        print(f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}")
    if args.cpu_offload:
        pipe.enable_model_cpu_offload(device)
    else:
        pipe.to(device)

    # Generate videos from the input prompt
    if args.prompt_embed_path is not None:
        prompt_embeds = (torch.load(args.prompt_embed_path, map_location="cpu", weights_only=True).to(device).unsqueeze(0))
        encoder_attention_mask = (torch.load(args.encoder_attention_mask_path, map_location="cpu", weights_only=True).to(device).unsqueeze(0))
        prompts = None
    elif args.prompt_path is not None:
        prompts = [line.strip() for line in open(args.prompt_path, "r")]
        prompt_embeds = None
        encoder_attention_mask = None
    else:
        prompts = args.prompts
        prompt_embeds = None
        encoder_attention_mask = None

    if prompts is not None:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            for prompt in prompts:
                generator = torch.Generator("cpu").manual_seed(args.seed)
                video = pipe(
                    prompt=[prompt],
                    height=args.height,
                    width=args.width,
                    num_frames=args.num_frames,
                    num_inference_steps=args.num_inference_steps,
                    generator=generator,
                ).frames
                if nccl_info.global_rank <= 0:
                    os.makedirs(args.output_path, exist_ok=True)
                    suffix = prompt.split(".")[0]
                    export_to_video(video[0], os.path.join(args.output_path, f"{suffix}.mp4"), fps=24)
    else:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            generator = torch.Generator("cpu").manual_seed(args.seed)
            videos = pipe(
                prompt_embeds=prompt_embeds,
                prompt_attention_mask=encoder_attention_mask,
                height=args.height,
                width=args.width,
                num_frames=args.num_frames,
                num_inference_steps=args.num_inference_steps,
                generator=generator,
            ).frames
        if nccl_info.global_rank <= 0:
            export_to_video(videos[0], args.output_path + ".mp4", fps=24)


def inference_quantization(args):
    torch.manual_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_id = args.model_path

    if args.quantization == "nf4":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            llm_int8_skip_modules=["proj_out", "norm_out"])
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id,
            subfolder="transformer/",
            torch_dtype=torch.bfloat16,
            quantization_config=quantization_config)
    if args.quantization == "int8":
        quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out", "norm_out"])
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id,
            subfolder="transformer/",
            torch_dtype=torch.bfloat16,
            quantization_config=quantization_config)
    elif not args.quantization:
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id, subfolder="transformer/", torch_dtype=torch.bfloat16).to(device)

    print("Max vram for read transformer:",
          round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB")
    torch.cuda.reset_max_memory_allocated(device)

    if not args.cpu_offload:
        pipe = HunyuanVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
        pipe.transformer = transformer
    else:
        pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
    torch.cuda.reset_max_memory_allocated(device)
    pipe.scheduler._shift = args.flow_shift
    pipe.vae.enable_tiling()
    if args.cpu_offload:
        pipe.enable_model_cpu_offload()
    print("Max vram for init pipeline:",
          round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB")

    if args.prompt.endswith('.txt'):
        with open(args.prompt) as f:
            prompts = [line.strip() for line in f.readlines()]
    else:
        prompts = [args.prompt]

    generator = torch.Generator("cpu").manual_seed(args.seed)
    os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
    torch.cuda.reset_max_memory_allocated(device)
    for prompt in prompts:
        start_time = time.perf_counter()
        output = pipe(
            prompt=prompt,
            height=args.height,
            width=args.width,
            num_frames=args.num_frames,
            num_inference_steps=args.num_inference_steps,
            generator=generator,
        ).frames[0]
        export_to_video(output, os.path.join(args.output_path, f"{prompt[:100]}.mp4"), fps=args.fps)
        print("Time:", round(time.perf_counter() - start_time, 2), "seconds")
        print("Max vram for denoise:",
              round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Basic parameters
    parser.add_argument("--prompt", type=str, help="prompt file for inference")
    parser.add_argument("--prompt_embed_path", type=str, default=None)
    parser.add_argument("--prompt_path", type=str, default=None)
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--height", type=int, default=256)
    parser.add_argument("--width", type=int, default=256)
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--model_path", type=str, default="data/hunyuan")
    parser.add_argument("--transformer_path", type=str, default=None)
    parser.add_argument("--output_path", type=str, default="./outputs/video")
    parser.add_argument("--fps", type=int, default=24)
    parser.add_argument("--quantization", type=str, default=None)
    parser.add_argument("--cpu_offload", action="store_true")
    parser.add_argument(
        "--lora_checkpoint_dir",
        type=str,
        default=None,
        help="Path to the directory containing LoRA checkpoints",
    )

    # Additional parameters
    parser.add_argument("--denoise-type", type=str, default="flow", help="Denoise type for noised inputs.")
    parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
    parser.add_argument("--neg_prompt", type=str, default=None, help="Negative prompt for sampling.")
    parser.add_argument("--guidance_scale", type=float, default=1.0, help="Classifier free guidance scale.")
    parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embedded classifier free guidance scale.")
    parser.add_argument("--flow_shift", type=int, default=7, help="Flow shift parameter.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference.")
    parser.add_argument("--num_videos", type=int, default=1, help="Number of videos to generate per prompt.")
    parser.add_argument(
        "--load-key",
        type=str,
        default="module",
        help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
    )
    parser.add_argument(
        "--dit-weight",
        type=str,
        default="data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
    )
    parser.add_argument(
        "--reproduce",
        action="store_true",
        help="Enable reproducibility by setting random seeds and deterministic algorithms.",
    )
    parser.add_argument(
        "--disable-autocast",
        action="store_true",
        help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
    )

    # Flow Matching
    parser.add_argument(
        "--flow-reverse",
        action="store_true",
        help="If reverse, learning/sampling from t=1 -> t=0.",
    )
    parser.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.")
    parser.add_argument(
        "--use-linear-quadratic-schedule",
        action="store_true",
        help="Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
    )
    parser.add_argument(
        "--linear-schedule-end",
        type=int,
        default=25,
        help="End step for linear quadratic schedule for flow matching.",
    )

    # Model parameters
    parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill")
    parser.add_argument("--latent-channels", type=int, default=16)
    parser.add_argument("--precision", type=str, default="bf16", choices=["fp32", "fp16", "bf16", "fp8"])
    parser.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")
    parser.add_argument("--vae", type=str, default="884-16c-hy")
    parser.add_argument("--vae-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--vae-tiling", action="store_true", default=True)
    parser.add_argument("--text-encoder", type=str, default="llm")
    parser.add_argument("--text-encoder-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--text-states-dim", type=int, default=4096)
    parser.add_argument("--text-len", type=int, default=256)
    parser.add_argument("--tokenizer", type=str, default="llm")
    parser.add_argument("--prompt-template", type=str, default="dit-llm-encode")
    parser.add_argument("--prompt-template-video", type=str, default="dit-llm-encode-video")
    parser.add_argument("--hidden-state-skip-layer", type=int, default=2)
    parser.add_argument("--apply-final-norm", action="store_true")
    parser.add_argument("--text-encoder-2", type=str, default="clipL")
    parser.add_argument("--text-encoder-precision-2", type=str, default="fp16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--text-states-dim-2", type=int, default=768)
    parser.add_argument("--tokenizer-2", type=str, default="clipL")
    parser.add_argument("--text-len-2", type=int, default=77)
    args = parser.parse_args()
    if args.quantization:
        inference_quantization(args)
    else:
        inference(args)
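Editor's note: `inference_quantization` interleaves its timing and peak-memory reporting with the pipeline calls. A small helper like the one below (the name and structure are an editorial suggestion, not FastVideo API) keeps the same `torch.cuda` peak-memory bookkeeping reusable around any block, e.g. `with report_peak_vram("denoise"): output = pipe(...)`.

```python
import torch
from contextlib import contextmanager

@contextmanager
def report_peak_vram(tag, device="cuda"):
    # Reset the allocation high-water mark, run the wrapped block, then report the peak in GiB.
    torch.cuda.reset_peak_memory_stats(device)
    yield
    peak_gib = torch.cuda.max_memory_allocated(device=device) / 1024 ** 3
    print(f"Max vram for {tag}:", round(peak_gib, 3), "GiB")
```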
FastVideo-main/fastvideo/sample/sample_t2v_mochi.py
0 → 100644
View file @
c07946d8
import argparse
import json
import os

import torch
import torch.distributed as dist
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video

from fastvideo.distill.solver import PCMFMScheduler
from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, nccl_info


def initialize_distributed():
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)
    initialize_sequence_parallel_state(world_size)


def main(args):
    initialize_distributed()
    print(nccl_info.sp_size)
    device = torch.cuda.current_device()
    # Peiyuan: GPU seed will cause A100 and H100 to produce different results .....
    if args.scheduler_type == "euler":
        scheduler = FlowMatchEulerDiscreteScheduler()
    else:
        linear_quadratic = True if "linear_quadratic" in args.scheduler_type else False
        scheduler = PCMFMScheduler(
            1000,
            args.shift,
            args.num_euler_timesteps,
            linear_quadratic,
            args.linear_threshold,
            args.linear_range,
        )
    if args.transformer_path is not None:
        transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path)
    else:
        transformer = MochiTransformer3DModel.from_pretrained(args.model_path, subfolder="transformer/")

    pipe = MochiPipeline.from_pretrained(args.model_path, transformer=transformer, scheduler=scheduler)
    pipe.enable_vae_tiling()

    if args.lora_checkpoint_dir is not None:
        print(f"Loading LoRA weights from {args.lora_checkpoint_dir}")
        config_path = os.path.join(args.lora_checkpoint_dir, "lora_config.json")
        with open(config_path, "r") as f:
            lora_config_dict = json.load(f)
        rank = lora_config_dict["lora_params"]["lora_rank"]
        lora_alpha = lora_config_dict["lora_params"]["lora_alpha"]
        lora_scaling = lora_alpha / rank
        pipe.load_lora_weights(args.lora_checkpoint_dir, adapter_name="default")
        pipe.set_adapters(["default"], [lora_scaling])
        print(f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}")

    # pipe.to(device)
    pipe.enable_model_cpu_offload(device)

    # Generate videos from the input prompt
    if args.prompt_embed_path is not None:
        prompt_embeds = (torch.load(args.prompt_embed_path, map_location="cpu", weights_only=True).to(device).unsqueeze(0))
        encoder_attention_mask = (torch.load(args.encoder_attention_mask_path, map_location="cpu", weights_only=True).to(device).unsqueeze(0))
        prompts = None
    elif args.prompt_path is not None:
        prompts = [line.strip() for line in open(args.prompt_path, "r")]
        prompt_embeds = None
        encoder_attention_mask = None
    else:
        prompts = args.prompts
        prompt_embeds = None
        encoder_attention_mask = None

    if prompts is not None:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            for prompt in prompts:
                generator = torch.Generator("cpu").manual_seed(args.seed)
                video = pipe(
                    prompt=[prompt],
                    height=args.height,
                    width=args.width,
                    num_frames=args.num_frames,
                    num_inference_steps=args.num_inference_steps,
                    guidance_scale=args.guidance_scale,
                    generator=generator,
                ).frames
                if nccl_info.global_rank <= 0:
                    os.makedirs(args.output_path, exist_ok=True)
                    suffix = prompt.split(".")[0]
                    export_to_video(video[0], os.path.join(args.output_path, f"{suffix}.mp4"), fps=30)
    else:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            generator = torch.Generator("cpu").manual_seed(args.seed)
            videos = pipe(
                prompt_embeds=prompt_embeds,
                prompt_attention_mask=encoder_attention_mask,
                height=args.height,
                width=args.width,
                num_frames=args.num_frames,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=args.guidance_scale,
                generator=generator,
            ).frames
        if nccl_info.global_rank <= 0:
            export_to_video(videos[0], args.output_path + ".mp4", fps=30)


if __name__ == "__main__":
    # arg parse
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompts", nargs="+", default=[])
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument("--height", type=int, default=480)
    parser.add_argument("--width", type=int, default=848)
    parser.add_argument("--num_inference_steps", type=int, default=64)
    parser.add_argument("--guidance_scale", type=float, default=4.5)
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output_path", type=str, default="./outputs.mp4")
    parser.add_argument("--transformer_path", type=str, default=None)
    parser.add_argument("--prompt_embed_path", type=str, default=None)
    parser.add_argument("--prompt_path", type=str, default=None)
    parser.add_argument("--scheduler_type", type=str, default="euler")
    parser.add_argument("--encoder_attention_mask_path", type=str, default=None)
    parser.add_argument(
        "--lora_checkpoint_dir",
        type=str,
        default=None,
        help="Path to the directory containing LoRA checkpoints",
    )
    parser.add_argument("--shift", type=float, default=8.0)
    parser.add_argument("--num_euler_timesteps", type=int, default=100)
    parser.add_argument("--linear_threshold", type=float, default=0.025)
    parser.add_argument("--linear_range", type=float, default=0.5)
    args = parser.parse_args()
    main(args)
FastVideo-main/fastvideo/sample/sample_t2v_mochi_no_sp.py
0 → 100644
View file @
c07946d8
import argparse

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video

from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline


def main(args):
    # Set the random seed for reproducibility
    generator = torch.Generator("cuda").manual_seed(args.seed)
    # do not invert
    scheduler = FlowMatchEulerDiscreteScheduler()
    if args.transformer_path is not None:
        transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path)
    else:
        transformer = MochiTransformer3DModel.from_pretrained(args.model_path, subfolder="transformer/")
    pipe = MochiPipeline.from_pretrained(args.model_path, transformer=transformer, scheduler=scheduler)

    pipe.enable_vae_tiling()
    # pipe.to("cuda:1")
    pipe.enable_model_cpu_offload()

    # Generate videos from the input prompt
    with torch.autocast("cuda", dtype=torch.bfloat16):
        videos = pipe(
            prompt=args.prompts,
            height=args.height,
            width=args.width,
            num_frames=args.num_frames,
            generator=generator,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
        ).frames

    for prompt, video in zip(args.prompts, videos):
        export_to_video(video, args.output_path + f"_{prompt}.mp4", fps=30)


if __name__ == "__main__":
    # arg parse
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompts", nargs="+", default=[])
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument("--height", type=int, default=480)
    parser.add_argument("--width", type=int, default=848)
    parser.add_argument("--num_inference_steps", type=int, default=64)
    parser.add_argument("--guidance_scale", type=float, default=4.5)
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--seed", type=int, default=12345)
    parser.add_argument("--transformer_path", type=str, default=None)
    parser.add_argument("--output_path", type=str, default="./outputs.mp4")
    args = parser.parse_args()
    main(args)
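Editor's note: this single-GPU variant seeds a CUDA generator, while sample_t2v_mochi.py seeds a CPU generator (its inline comment warns that GPU-side seeding makes A100 and H100 produce different results). A CPU generator passed into the pipeline keeps the initial noise reproducible across GPU architectures, as the minimal snippet below illustrates.

```python
import torch

# CPU generator: the initial latent noise is drawn on the CPU, so a fixed seed
# yields identical noise regardless of which GPU architecture runs the model.
generator = torch.Generator("cpu").manual_seed(12345)
noise = torch.randn((1, 4, 8, 8), generator=generator)
```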
FastVideo-main/fastvideo/sample/sample_t2v_stepvideo.py
0 → 100644
View file @
c07946d8
import argparse
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from fastvideo.models.stepvideo.diffusion.scheduler import FlowMatchDiscreteScheduler
from fastvideo.models.stepvideo.diffusion.video_pipeline import StepVideoPipeline
from fastvideo.models.stepvideo.modules.model import StepVideoModel
from fastvideo.models.stepvideo.utils import setup_seed
from fastvideo.models.stepvideo.utils.quantization import convert_fp8_linear, fp8_linear_forward
from fastvideo.utils.logging_ import main_print
from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, nccl_info


def initialize_distributed():
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)
    initialize_sequence_parallel_state(world_size)


def parse_args(namespace=None):
    parser = argparse.ArgumentParser(description="StepVideo inference script")
    parser = add_extra_models_args(parser)
    parser = add_denoise_schedule_args(parser)
    parser = add_inference_args(parser)
    args = parser.parse_args(namespace=namespace)
    return args


def add_extra_models_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Extra models args, including vae, text encoders and tokenizers)")
    group.add_argument("--vae_url", type=str, default='127.0.0.1', help="vae url.")
    group.add_argument("--caption_url", type=str, default='127.0.0.1', help="caption url.")
    return parser


def add_denoise_schedule_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Denoise schedule args")
    # Flow Matching
    group.add_argument("--time_shift", type=float, default=13, help="Shift factor for flow matching schedulers.")
    group.add_argument("--flow_reverse", action="store_true", help="If reverse, learning/sampling from t=1 -> t=0.")
    group.add_argument("--flow_solver", type=str, default="euler", help="Solver for flow matching.")
    return parser


def add_inference_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Inference args")

    # ======================== Model loads ========================
    group.add_argument(
        "--model_dir",
        type=str,
        default="./ckpts",
        help="Root path of all the models, including t2v models and extra models.",
    )
    group.add_argument(
        "--model_resolution",
        type=str,
        default="540p",
        choices=["540p"],
        help="Root path of all the models, including t2v models and extra models.",
    )
    group.add_argument("--use-cpu-offload", action="store_true", help="Use CPU offload for the model load.")
    group.add_argument("--use-fp8", action="store_true", help="FP8 Quantization for single GPU support.")

    # ======================== Inference general setting ========================
    group.add_argument("--batch_size", type=int, default=1, help="Batch size for inference and evaluation.")
    group.add_argument("--infer_steps", type=int, default=50, help="Number of denoising steps for inference.")
    group.add_argument("--save_path", type=str, default="./results", help="Path to save the generated samples.")
    group.add_argument("--name_suffix", type=str, default="", help="Suffix for the names of saved samples.")
    group.add_argument("--num_videos", type=int, default=1, help="Number of videos to generate for each prompt.")
    # ---sample size---
    group.add_argument("--num_frames", type=int, default=204, help="How many frames to sample from a video. ")
    group.add_argument("--height", type=int, default=768, help="The height of video sample")
    group.add_argument("--width", type=int, default=768, help="The width of video sample")
    # --- prompt ---
    group.add_argument("--prompt", type=str, default=None, help="Prompt for sampling during evaluation.")
    group.add_argument("--seed", type=int, default=1234, help="Seed for evaluation.")

    # Classifier-Free Guidance
    group.add_argument(
        "--pos_magic",
        type=str,
        default="超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。",
        help="Positive magic prompt for sampling.")
    group.add_argument(
        "--neg_magic",
        type=str,
        default="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
        help="Negative magic prompt for sampling.")
    group.add_argument("--cfg_scale", type=float, default=9.0, help="Classifier free guidance scale.")
    return parser


if __name__ == "__main__":
    args = parse_args()
    initialize_distributed()
    main_print(f"sequence parallel size: {nccl_info.sp_size}")

    device = torch.cuda.current_device()
    setup_seed(args.seed)

    main_print("Loading model, this might take a while...")
    scheduler = FlowMatchDiscreteScheduler()

    if args.use_fp8:
        assert int(os.getenv("WORLD_SIZE", 1)) == 1
        transformer = StepVideoModel.from_pretrained(
            os.path.join(args.model_dir, "transformer"), torch_dtype=torch.bfloat16, device="cpu")
        if not os.path.exists(args.model_dir + "/fp8_transformer.pth"):
            print("no_fp8 weight, creating...")
            scale_dict = convert_fp8_linear(transformer, torch.bfloat16)
            torch.save(transformer.state_dict(), args.model_dir + "/fp8_transformer.pth")
            torch.save(scale_dict, args.model_dir + "/fp8_scale_dict.pth")
        else:
            transformer.load_state_dict(torch.load(args.model_dir + "/fp8_transformer.pth"))
            scale_dict = torch.load(args.model_dir + "/fp8_scale_dict.pth")

        original_dtype = torch.bfloat16
        for key, layer in transformer.named_modules():
            if isinstance(layer, nn.Linear) and 'transformer_blocks' in key and key in scale_dict:
                layer.weight.data = layer.weight.data.to(torch.float8_e4m3fn)
                print(f"{key}, layer.weight.dtype: {layer.weight.dtype}")
                original_forward = layer.forward
                scale = scale_dict[key]
                setattr(layer, "fp8_scale", scale.to(dtype=original_dtype))
                setattr(layer, "original_forward", original_forward)
                setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
    else:
        transformer = StepVideoModel.from_pretrained(
            os.path.join(args.model_dir, "transformer"), torch_dtype=torch.bfloat16, device=device)

    transformer = transformer.to(device)
    pipeline = StepVideoPipeline(transformer, scheduler, save_path=args.save_path)
    pipeline.setup_api(
        vae_url=args.vae_url,
        caption_url=args.caption_url,
    )

    if args.prompt.endswith('.txt'):
        with open(args.prompt) as f:
            prompts = [line.strip() for line in f.readlines()]
    else:
        prompts = [args.prompt]

    for prompt in prompts:
        videos = pipeline(
            prompt=prompt,
            num_frames=args.num_frames,
            height=args.height,
            width=args.width,
            num_inference_steps=args.infer_steps,
            guidance_scale=args.cfg_scale,
            time_shift=args.time_shift,
            pos_magic=args.pos_magic,
            neg_magic=args.neg_magic,
            output_file_name=prompt[:50])

    dist.destroy_process_group()
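Editor's note: the FP8 path above stores each `transformer_blocks` Linear weight as `float8_e4m3fn` plus a bfloat16 per-layer scale, and reroutes `forward` through `fp8_linear_forward` from `fastvideo.models.stepvideo.utils.quantization`, whose body is not part of this diff. The sketch below shows one plausible shape of such a forward (upcast the stored weight and apply the per-layer scale before the matmul); the exact scaling convention FastVideo uses may differ, so treat this purely as an illustration.

```python
import torch
import torch.nn.functional as F

def fp8_linear_forward_sketch(layer: torch.nn.Linear, original_dtype: torch.dtype, input: torch.Tensor):
    # Illustrative only: the weight was cast to float8_e4m3fn at load time and
    # `layer.fp8_scale` holds a per-layer scale in the original dtype.
    weight = layer.weight.to(original_dtype) * layer.fp8_scale
    return F.linear(input.to(original_dtype), weight, layer.bias)
```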
FastVideo-main/fastvideo/sample/sample_t2v_stepvideo_STA.py
0 → 100644
View file @
c07946d8
import argparse
import json
import os
import types
from typing import Dict, Optional

import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange, repeat

from fastvideo.models.stepvideo.diffusion.scheduler import FlowMatchDiscreteScheduler
from fastvideo.models.stepvideo.diffusion.video_pipeline import StepVideoPipeline
from fastvideo.models.stepvideo.modules.model import StepVideoModel
from fastvideo.models.stepvideo.utils import setup_seed
from fastvideo.utils.logging_ import main_print
from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, nccl_info


def initialize_distributed():
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    main_print(f"world_size: {world_size}")
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)
    initialize_sequence_parallel_state(world_size)


def parse_args(namespace=None):
    parser = argparse.ArgumentParser(description="StepVideo inference script")
    parser = add_extra_models_args(parser)
    parser = add_denoise_schedule_args(parser)
    parser = add_inference_args(parser)
    args = parser.parse_args(namespace=namespace)
    return args


def add_extra_models_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Extra models args, including vae, text encoders and tokenizers)")
    group.add_argument("--vae_url", type=str, default='127.0.0.1', help="vae url.")
    group.add_argument("--caption_url", type=str, default='127.0.0.1', help="caption url.")
    return parser


def add_denoise_schedule_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Denoise schedule args")
    # Flow Matching
    group.add_argument("--time_shift", type=float, default=13, help="Shift factor for flow matching schedulers.")
    group.add_argument("--flow_reverse", action="store_true", help="If reverse, learning/sampling from t=1 -> t=0.")
    group.add_argument("--flow_solver", type=str, default="euler", help="Solver for flow matching.")
    return parser


def add_inference_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Inference args")

    # ======================== Model loads ========================
    group.add_argument(
        "--model_dir",
        type=str,
        default="./ckpts",
        help="Root path of all the models, including t2v models and extra models.",
    )
    group.add_argument(
        "--model_resolution",
        type=str,
        default="540p",
        choices=["540p"],
        help="Root path of all the models, including t2v models and extra models.",
    )
    group.add_argument("--use-cpu-offload", action="store_true", help="Use CPU offload for the model load.")

    # ======================== Inference general setting ========================
    group.add_argument("--batch_size", type=int, default=1, help="Batch size for inference and evaluation.")
    group.add_argument("--infer_steps", type=int, default=50, help="Number of denoising steps for inference.")
    group.add_argument("--save_path", type=str, default="./results", help="Path to save the generated samples.")
    group.add_argument("--name_suffix", type=str, default="", help="Suffix for the names of saved samples.")
    group.add_argument("--num_videos", type=int, default=1, help="Number of videos to generate for each prompt.")
    # ---sample size---
    group.add_argument("--num_frames", type=int, default=204, help="How many frames to sample from a video. ")
    group.add_argument("--height", type=int, default=768, help="The height of video sample")
    group.add_argument("--width", type=int, default=768, help="The width of video sample")
    # --- prompt ---
    group.add_argument("--prompt", type=str, default=None, help="Prompt for sampling during evaluation.")
    group.add_argument("--seed", type=int, default=1234, help="Seed for evaluation.")

    # Classifier-Free Guidance
    group.add_argument(
        "--pos_magic",
        type=str,
        default="超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。",
        help="Positive magic prompt for sampling.")
    group.add_argument(
        "--neg_magic",
        type=str,
        default="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
        help="Negative magic prompt for sampling.")
    group.add_argument("--cfg_scale", type=float, default=9.0, help="Classifier free guidance scale.")
    group.add_argument("--mask_search_files_path", type=str, default="assets/mask_strategy.json")
    group.add_argument("--mask_strategy_file_path", type=str, default="assets/mask_strategy_stepvideo.json")
    group.add_argument("--skip_time_steps", type=int, default=10)
    group.add_argument(
        "--mask_strategy_selected",
        type=lambda x: [int(i) for i in x.strip('[]').split(',')],  # Convert string to list of integers
        default=[1, 2, 6],  # Now can be directly set as a list
        help="order of candidates",
    )
    parser.add_argument(
        "--rel_l1_thresh",
        type=float,
        default=0,
        help="0.22 for 1.67x speedup, 0.23 for 2.1x speedup",
    )
    parser.add_argument("--enable_teacache", action="store_true", help="Use teacache for speeding up inference")
    return parser


def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_hidden_states_2: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    added_cond_kwargs: Dict[str, torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    fps: torch.Tensor = None,
    return_dict: bool = True,
    mask_strategy=None,
):
    assert hidden_states.ndim == 5, "hidden_states's shape should be (bsz, f, ch, h ,w)"
    bsz, frame, _, height, width = hidden_states.shape
    height, width = height // self.patch_size, width // self.patch_size

    hidden_states = self.patchfy(hidden_states)
    len_frame = hidden_states.shape[1]

    if self.use_additional_conditions:
        added_cond_kwargs = {
            "resolution": torch.tensor([(height, width)] * bsz, device=hidden_states.device, dtype=hidden_states.dtype),
            "nframe": torch.tensor([frame] * bsz, device=hidden_states.device, dtype=hidden_states.dtype),
            "fps": fps
        }
    else:
        added_cond_kwargs = {}

    timestep, embedded_timestep = self.adaln_single(timestep, added_cond_kwargs=added_cond_kwargs)

    encoder_hidden_states = self.caption_projection(self.caption_norm(encoder_hidden_states))

    if encoder_hidden_states_2 is not None and hasattr(self, 'clip_projection'):
        clip_embedding = self.clip_projection(encoder_hidden_states_2)
        encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=1)

    hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous()
    embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame).contiguous()
    shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)

    encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, q_seqlen=frame * len_frame)

    if self.enable_teacache:
        hidden_states_ = hidden_states.clone()
        normed_hidden_states = self.transformer_blocks[0].norm1(hidden_states_)
        normed_hidden_states = rearrange(normed_hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
        modulated_inp = normed_hidden_states * (1 + scale) + shift
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = [6.74352814e+03, -2.22814115e+03, 2.55029094e+02, -1.12338285e+01, 2.84921593e-01]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.previous_modulated_input).abs().mean() /
                 self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                # print(f"accumulated_rel_l1_distance: {self.accumulated_rel_l1_distance}")
                should_calc = False
            else:
                # print(f"accumulated_rel_l1_distance: {self.accumulated_rel_l1_distance}")
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.cnt = 0

    if self.enable_teacache:
        if not should_calc:
            # print(f"skip step {self.cnt}")
            hidden_states += self.previous_residual
        else:
            # print(f"calc step {self.cnt}")
            ori_hidden_states = hidden_states.clone()
            hidden_states = self.block_forward(
                hidden_states,
                encoder_hidden_states,
                timestep=timestep,
                rope_positions=[frame, height, width],
                attn_mask=attn_mask,
                parallel=self.parallel,
                mask_strategy=mask_strategy)
            self.previous_residual = hidden_states - ori_hidden_states
    else:
        # --------------------- Pass through DiT blocks ------------------------
        hidden_states = self.block_forward(
            hidden_states,
            encoder_hidden_states,
            timestep=timestep,
            rope_positions=[frame, height, width],
            attn_mask=attn_mask,
            parallel=self.parallel,
            mask_strategy=mask_strategy)

    # ---------------------------- Final layer ------------------------------
    hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
    hidden_states = self.norm_out(hidden_states)
    # Modulation
    hidden_states = hidden_states * (1 + scale) + shift
    hidden_states = self.proj_out(hidden_states)

    # unpatchify
    hidden_states = hidden_states.reshape(shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels))
    hidden_states = rearrange(hidden_states, 'n h w p q c -> n c h p w q')
    output = hidden_states.reshape(shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size))
    output = rearrange(output, '(b f) c h w -> b f c h w', f=frame)
    if return_dict:
        return {'x': output}
    return output


if __name__ == "__main__":
    args = parse_args()
    initialize_distributed()
    main_print(f"sequence parallel size: {nccl_info.sp_size}")

    device = torch.cuda.current_device()
    setup_seed(args.seed)

    main_print("Loading model, this might take a while...")
    transformer = StepVideoModel.from_pretrained(
        os.path.join(args.model_dir, "transformer"), torch_dtype=torch.bfloat16, device_map=device)
    if args.enable_teacache:
        transformer.forward = types.MethodType(teacache_forward, transformer)
    scheduler = FlowMatchDiscreteScheduler()
    pipeline = StepVideoPipeline(transformer, scheduler, save_path=args.save_path)
    pipeline.setup_api(
        vae_url=args.vae_url,
        caption_url=args.caption_url,
    )

    # TeaCache
    pipeline.transformer.__class__.enable_teacache = True
    pipeline.transformer.__class__.cnt = 0
    pipeline.transformer.__class__.num_steps = args.infer_steps
    pipeline.transformer.__class__.rel_l1_thresh = args.rel_l1_thresh  # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
    pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
    pipeline.transformer.__class__.previous_modulated_input = None
    pipeline.transformer.__class__.previous_residual = None

    with open(args.mask_strategy_file_path, 'r') as f:
        mask_strategy = json.load(f)

    if args.prompt.endswith('.txt'):
        with open(args.prompt) as f:
            prompts = [line.strip() for line in f.readlines()]
    else:
        prompts = [args.prompt]

    for prompt in prompts:
        main_print(f"Generating video for prompt: {prompt}")
        videos = pipeline(
            prompt=prompt,
            num_frames=args.num_frames,
            height=args.height,
            width=args.width,
            num_inference_steps=args.infer_steps,
            guidance_scale=args.cfg_scale,
            time_shift=args.time_shift,
            pos_magic=args.pos_magic,
            neg_magic=args.neg_magic,
            output_file_name=prompt[:150],
            mask_strategy=mask_strategy)

    dist.destroy_process_group()
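Editor's note: unlike sample_t2v_hunyuan_STA.py, which swaps `forward` at the class level (`__class__.forward = teacache_forward`), this script rebinds `forward` only on the one transformer instance via `types.MethodType`, while the TeaCache counters are still set as class attributes. The toy example below (names are illustrative) shows what instance-level rebinding does.

```python
import types

class Greeter:
    def hello(self):
        return "original"

def patched_hello(self):
    return "patched"

g = Greeter()
g.hello = types.MethodType(patched_hello, g)   # rebinds only this instance
assert g.hello() == "patched"
assert Greeter().hello() == "original"         # other instances keep the original method
```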
FastVideo-main/fastvideo/train.py
0 → 100644
View file @
c07946d8
# !/bin/python3
# isort: skip_file
import argparse
import math
import os
import time
from collections import deque
import torch
import torch.distributed as dist
import wandb
from accelerate.utils import set_seed
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft
from peft import LoraConfig, set_peft_model_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from fastvideo.dataset.latent_datasets import (LatentDataset, latent_collate_function)
from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline
from fastvideo.utils.checkpoint import (resume_lora_optimizer, save_checkpoint, save_lora_checkpoint)
from fastvideo.utils.communications import (broadcast, sp_parallel_dataloader_wrapper)
from fastvideo.utils.dataset_utils import LengthGroupedSampler
from fastvideo.utils.fsdp_util import (apply_fsdp_checkpointing, get_dit_fsdp_kwargs)
from fastvideo.utils.load import load_transformer
from fastvideo.utils.logging_ import main_print
from fastvideo.utils.parallel_states import (destroy_sequence_parallel_group, get_sequence_parallel_state,
                                             initialize_sequence_parallel_state)
from fastvideo.utils.validation import log_validation

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0")
torch._dynamo.config.capture_scalar_outputs = True


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    generator,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.
    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(
            mean=logit_mean,
            std=logit_std,
            size=(batch_size, ),
            device="cpu",
            generator=generator,
        )
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2)**2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
    return u


def get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32):
    sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)
    # print("timesteps:",timesteps)
    # print("schedule_timesteps:",schedule_timesteps)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma


def train_one_step(
    transformer,
    model_type,
    optimizer,
    lr_scheduler,
    loader,
    noise_scheduler,
    noise_random_generator,
    gradient_accumulation_steps,
    sp_size,
    precondition_outputs,
    max_grad_norm,
    weighting_scheme,
    logit_mean,
    logit_std,
    mode_scale,
):
    total_loss = 0.0
    optimizer.zero_grad()
    for _ in range(gradient_accumulation_steps):
        (
            latents,
            encoder_hidden_states,
            latents_attention_mask,
            encoder_attention_mask,
        ) = next(loader)
        latents = normalize_dit_input(model_type, latents)
        batch_size = latents.shape[0]
        noise = torch.randn_like(latents)
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme,
            batch_size=batch_size,
            generator=noise_random_generator,
            logit_mean=logit_mean,
            logit_std=logit_std,
            mode_scale=mode_scale,
        )
        indices = (u * noise_scheduler.config.num_train_timesteps).long()
        timesteps = noise_scheduler.timesteps[indices].to(device=latents.device)
        if sp_size > 1:
            # Make sure that the timesteps are the same across all sp processes.
            broadcast(timesteps)
        sigmas = get_sigmas(
            noise_scheduler,
            latents.device,
            timesteps,
            n_dim=latents.ndim,
            dtype=latents.dtype,
        )
        noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
        with torch.autocast("cuda", dtype=torch.bfloat16):
            input_kwargs = {
                "hidden_states": noisy_model_input,
                "encoder_hidden_states": encoder_hidden_states,
                "timestep": timesteps,
                "encoder_attention_mask": encoder_attention_mask,  # B, L
                "return_dict": False,
            }
            if 'hunyuan' in model_type:
                input_kwargs["guidance"] = torch.tensor([1000.0], device=noisy_model_input.device, dtype=torch.bfloat16)
            model_pred = transformer(**input_kwargs)[0]
        if precondition_outputs:
            model_pred = noisy_model_input - model_pred * sigmas
        if precondition_outputs:
            target = latents
        else:
            target = noise - latents
        loss = (torch.mean((model_pred.float() - target.float())**2) / gradient_accumulation_steps)
        loss.backward()
        avg_loss = loss.detach().clone()
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        total_loss += avg_loss.item()
    grad_norm = transformer.clip_grad_norm_(max_grad_norm)
    optimizer.step()
    lr_scheduler.step()
    return total_loss, grad_norm.item()


def main(args):
    torch.backends.cuda.matmul.allow_tf32 = True
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    device = torch.cuda.current_device()
    initialize_sequence_parallel_state(args.sp_size)
    # If passed along, set the training seed now. On GPU...
    if args.seed is not None:
        # TODO: t within the same seq parallel group should be the same. Noise should be different.
        set_seed(args.seed + rank)
    # We use different seeds for the noise generation in each process to ensure that the noise is different in a batch.
    noise_random_generator = None
    # Handle the repository creation
    if rank <= 0 and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
    # For mixed precision training we cast all non-trainable weights to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    # Create model:
    main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
    # keep the master weight to float32
    print("<" * 50)
    transformer = load_transformer(
        args.model_type,
        args.dit_model_name_or_path,
        args.pretrained_model_name_or_path,
        torch.float32 if args.master_weight_type == "fp32" else torch.bfloat16,
    )
    print(">" * 50)
    if args.use_lora:
        assert args.model_type != "hunyuan", "LoRA is only supported for huggingface model. Please use hunyuan_hf for lora finetuning"
        if args.model_type == "mochi":
            pipe = MochiPipeline
        elif args.model_type == "hunyuan_hf":
            pipe = HunyuanVideoPipeline
        transformer.requires_grad_(False)
        transformer_lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            init_lora_weights=True,
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        transformer.add_adapter(transformer_lora_config)
    if args.resume_from_lora_checkpoint:
        lora_state_dict = pipe.lora_state_dict(args.resume_from_lora_checkpoint)
        transformer_state_dict = {
            f'{k.replace("transformer.", "")}': v
            for k, v in lora_state_dict.items() if k.startswith("transformer.")
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        incompatible_keys = set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                main_print(f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                           f" {unexpected_keys}. ")
    main_print(
        f" Total training parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M")
    main_print(f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}")
    fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
        transformer,
        args.fsdp_sharding_startegy,
        args.use_lora,
        args.use_cpu_offload,
        args.master_weight_type,
    )
    if args.use_lora:
        transformer.config.lora_rank = args.lora_rank
        transformer.config.lora_alpha = args.lora_alpha
        transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
        transformer._no_split_modules = [no_split_module.__name__ for no_split_module in no_split_modules]
        fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer)
    fsdp_kwargs['use_orig_params'] = True
    transformer = FSDP(
        transformer,
        **fsdp_kwargs,
    )
    # transformer = torch.compile(transformer)
    main_print("--> model loaded")
    if args.gradient_checkpointing:
        apply_fsdp_checkpointing(transformer, no_split_modules, args.selective_checkpointing)
    main_print(transformer)
    # Set model as trainable.
    transformer.train()
    noise_scheduler = FlowMatchEulerDiscreteScheduler()
    params_to_optimize = transformer.parameters()
    params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
        eps=1e-8,
    )
    init_steps = 0
    if args.resume_from_lora_checkpoint:
        transformer, optimizer, init_steps = resume_lora_optimizer(transformer, args.resume_from_lora_checkpoint,
                                                                   optimizer)
    main_print(f"optimizer: {optimizer}")
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
        last_epoch=init_steps - 1,
    )
    train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg)
    sampler = (LengthGroupedSampler(
        args.train_batch_size,
        rank=rank,
        world_size=world_size,
        lengths=train_dataset.lengths,
        group_frame=args.group_frame,
        group_resolution=args.group_resolution,
    ) if (args.group_frame or args.group_resolution) else DistributedSampler(
        train_dataset, rank=rank, num_replicas=world_size, shuffle=False))
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        collate_fn=latent_collate_function,
        pin_memory=True,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        drop_last=True,
    )
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps * args.sp_size / args.train_sp_batch_size)
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    # if rank <= 0:
    #     project = args.tracker_project_name or "fastvideo"
    #     wandb.init(project=project, config=args)
    # Train!
    total_batch_size = (world_size * args.gradient_accumulation_steps / args.sp_size * args.train_sp_batch_size)
    main_print("***** Running training *****")
    main_print(f" Num examples = {len(train_dataset)}")
    main_print(f" Dataloader size = {len(train_dataloader)}")
    main_print(f" Num Epochs = {args.num_train_epochs}")
    main_print(f" Resume training from step {init_steps}")
    main_print(f" Instantaneous batch size per device = {args.train_batch_size}")
    main_print(f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}")
    main_print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    main_print(f" Total optimization steps = {args.max_train_steps}")
    main_print(
        f" Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B"
    )
    # print dtype
    main_print(f" Master weight dtype: {transformer.parameters().__next__().dtype}")
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        assert NotImplementedError("resume_from_checkpoint is not supported now.")
        # TODO
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=init_steps,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=local_rank > 0,
    )
    loader = sp_parallel_dataloader_wrapper(
        train_dataloader,
        device,
        args.train_batch_size,
        args.sp_size,
        args.train_sp_batch_size,
    )
    step_times = deque(maxlen=100)
    # todo future
    for i in range(init_steps):
        next(loader)
    for step in range(init_steps + 1, args.max_train_steps + 1):
        start_time = time.time()
        #if False:
        if step == 100:
            from torch.profiler import profile, record_function, ProfilerActivity
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    record_shapes=True,
                    profile_memory=False,
                    with_stack=False) as prof:
                loss, grad_norm = train_one_step(
                    transformer,
                    args.model_type,
                    optimizer,
                    lr_scheduler,
                    loader,
                    noise_scheduler,
                    noise_random_generator,
                    args.gradient_accumulation_steps,
                    args.sp_size,
                    args.precondition_outputs,
                    args.max_grad_norm,
                    args.weighting_scheme,
                    args.logit_mean,
                    args.logit_std,
                    args.mode_scale,
                )
            print(prof.key_averages().table(sort_by="self_cuda_time_total"))
            prof.export_chrome_trace(
                f"/public/home/wuxk/code/modelzoo/FastVideo-main/scripts/finetune/prof/bw_fv_trace_ge_{dist.get_rank()}.json")
            # torch.cuda.synchronize()
        else:
            loss, grad_norm = train_one_step(
                transformer,
                args.model_type,
                optimizer,
                lr_scheduler,
                loader,
                noise_scheduler,
                noise_random_generator,
                args.gradient_accumulation_steps,
                args.sp_size,
                args.precondition_outputs,
                args.max_grad_norm,
                args.weighting_scheme,
                args.logit_mean,
                args.logit_std,
                args.mode_scale,
            )
        step_time = time.time() - start_time
        step_times.append(step_time)
        avg_step_time = sum(step_times) / len(step_times)
        progress_bar.set_postfix({
            "loss": f"{loss:.4f}",
            "step_time": f"{step_time:.2f}s",
            "grad_norm": grad_norm,
        })
        progress_bar.update(1)
        # if rank <= 0:
        #     wandb.log(
        #         {
        #             "train_loss": loss,
        #             "learning_rate": lr_scheduler.get_last_lr()[0],
        #             "step_time": step_time,
        #             "avg_step_time": avg_step_time,
        #             "grad_norm": grad_norm,
        #         },
        #         step=step,
        #     )
        main_print(f"zll step_time: {step_time:.2f}s avg_step_time: {sum(step_times) / len(step_times)}")
        if step % args.checkpointing_steps == 0:
            if args.use_lora:
                # Save LoRA weights
                save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, step, pipe)
            else:
                # Your existing checkpoint saving code
                save_checkpoint(transformer, rank, args.output_dir, step)
            dist.barrier()
        if args.log_validation and step % args.validation_steps == 0:
            log_validation(args, transformer, device, torch.bfloat16, step, shift=args.shift)
    if args.use_lora:
        save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipe)
    else:
        save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps)
    if get_sequence_parallel_state():
        destroy_sequence_parallel_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type",
                        type=str,
                        default="mochi",
                        help="The type of model to train. Currentlt support [mochi, hunyuan_hf, hunyuan]")
    # dataset & dataloader
    parser.add_argument("--data_json_path", type=str, required=True)
    parser.add_argument("--num_height", type=int, default=480)
    parser.add_argument("--num_width", type=int, default=848)
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=10,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
    parser.add_argument("--group_frame", action="store_true")  # TODO
    parser.add_argument("--group_resolution", action="store_true")  # TODO
    # text encoder & vae & diffusion model
    parser.add_argument("--pretrained_model_name_or_path", type=str)
    parser.add_argument("--dit_model_name_or_path", type=str, default=None)
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    # diffusion setting
    parser.add_argument("--ema_decay", type=float, default=0.999)
    parser.add_argument("--ema_start_step", type=int, default=0)
    parser.add_argument("--cfg", type=float, default=0.1)
    parser.add_argument(
        "--precondition_outputs",
        action="store_true",
        help="Whether to precondition the outputs of the model.",
    )
    # validation & logs
    parser.add_argument("--validation_prompt_dir", type=str)
    parser.add_argument("--uncond_prompt_dir", type=str)
    parser.add_argument(
        "--validation_sampling_steps",
        type=str,
        default="64",
        help="use ',' to split multi sampling steps",
    )
    parser.add_argument(
        "--validation_guidance_scale",
        type=str,
        default="4.5",
        help="use ',' to split multi scale",
    )
    parser.add_argument("--validation_steps", type=int, default=50)
    parser.add_argument("--log_validation", action="store_true")
    parser.add_argument("--tracker_project_name", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=("Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
              " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
              " training using `--resume_from_checkpoint`."),
    )
    parser.add_argument("--shift", type=float, default=1.0, help=("Set shift to 7 for hunyuan model."))
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=("Whether training should be resumed from a previous checkpoint. Use a path saved by"
              ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
    )
    parser.add_argument(
        "--resume_from_lora_checkpoint",
        type=str,
        default=None,
        help=("Whether training should be resumed from a previous lora checkpoint. Use a path saved by"
              ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )
    # optimizer & scheduler & Training
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=10,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument("--selective_checkpointing", type=float, default=1.0)
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
              " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
    )
    parser.add_argument(
        "--use_cpu_offload",
        action="store_true",
        help="Whether to use CPU offload for param & gradient & optimizer states.",
    )
    parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel")
    parser.add_argument(
        "--train_sp_batch_size",
        type=int,
        default=1,
        help="Batch size for sequence parallel training",
    )
    parser.add_argument(
        "--use_lora",
        action="store_true",
        default=False,
        help="Whether to use LoRA for finetuning.",
    )
    parser.add_argument("--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA.")
    parser.add_argument("--lora_rank", type=int, default=128, help="LoRA rank parameter. ")
    parser.add_argument("--fsdp_sharding_startegy", default="full")
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="uniform",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
    )
    parser.add_argument(
        "--logit_mean",
        type=float,
        default=0.0,
        help="mean to use when using the `'logit_normal'` weighting scheme.",
    )
    parser.add_argument(
        "--logit_std",
        type=float,
        default=1.0,
        help="std to use when using the `'logit_normal'` weighting scheme.",
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    # lr_scheduler
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
              ' "constant", "constant_with_warmup"]'),
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of cycles in the learning rate scheduler.",
    )
    parser.add_argument(
        "--lr_power",
        type=float,
        default=1.0,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to apply.")
    parser.add_argument(
        "--master_weight_type",
        type=str,
        default="fp32",
        help="Weight type to use - fp32 or bf16.",
    )
    args = parser.parse_args()
    main(args)
FastVideo-main/fastvideo/train.py-bak
0 → 100644
View file @
c07946d8
# !/bin/python3
# isort: skip_file
import argparse
import math
import os
import time
from collections import deque
import torch
import torch.distributed as dist
import wandb
from accelerate.utils import set_seed
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft
from peft import LoraConfig, set_peft_model_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from fastvideo.dataset.latent_datasets import (LatentDataset, latent_collate_function)
from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline
from fastvideo.utils.checkpoint import (resume_lora_optimizer, save_checkpoint, save_lora_checkpoint)
from fastvideo.utils.communications import (broadcast, sp_parallel_dataloader_wrapper)
from fastvideo.utils.dataset_utils import LengthGroupedSampler
from fastvideo.utils.fsdp_util import (apply_fsdp_checkpointing, get_dit_fsdp_kwargs)
from fastvideo.utils.load import load_transformer
from fastvideo.utils.logging_ import main_print
from fastvideo.utils.parallel_states import (destroy_sequence_parallel_group, get_sequence_parallel_state,
initialize_sequence_parallel_state)
from fastvideo.utils.validation import log_validation
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0")
def compute_density_for_timestep_sampling(
weighting_scheme: str,
batch_size: int,
generator,
logit_mean: float = None,
logit_std: float = None,
mode_scale: float = None,
):
"""
Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(
mean=logit_mean,
std=logit_std,
size=(batch_size, ),
device="cpu",
generator=generator,
)
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2)**2 - 1 + u)
else:
u = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
return u
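
# --- Illustrative note (added for this listing, not part of the original file) --
# A minimal sketch of how the sampled density `u` above becomes concrete
# scheduler timesteps; it simply mirrors the two lines used later in
# train_one_step. `_demo_timestep_sampling` is a hypothetical helper name used
# only for illustration here.
def _demo_timestep_sampling(noise_scheduler, batch_size=4, generator=None):
    u = compute_density_for_timestep_sampling(
        weighting_scheme="uniform",  # the default scheme in this script
        batch_size=batch_size,
        generator=generator,
    )
    # Map u in [0, 1) to integer indices into the discrete noise schedule,
    # then look up the corresponding timesteps (same as train_one_step below).
    indices = (u * noise_scheduler.config.num_train_timesteps).long()
    return noise_scheduler.timesteps[indices]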
def get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def train_one_step(
transformer,
model_type,
optimizer,
lr_scheduler,
loader,
noise_scheduler,
noise_random_generator,
gradient_accumulation_steps,
sp_size,
precondition_outputs,
max_grad_norm,
weighting_scheme,
logit_mean,
logit_std,
mode_scale,
):
total_loss = 0.0
optimizer.zero_grad()
for _ in range(gradient_accumulation_steps):
(
latents,
encoder_hidden_states,
latents_attention_mask,
encoder_attention_mask,
) = next(loader)
latents = normalize_dit_input(model_type, latents)
batch_size = latents.shape[0]
noise = torch.randn_like(latents)
u = compute_density_for_timestep_sampling(
weighting_scheme=weighting_scheme,
batch_size=batch_size,
generator=noise_random_generator,
logit_mean=logit_mean,
logit_std=logit_std,
mode_scale=mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=latents.device)
if sp_size > 1:
# Make sure that the timesteps are the same across all sp processes.
broadcast(timesteps)
sigmas = get_sigmas(
noise_scheduler,
latents.device,
timesteps,
n_dim=latents.ndim,
dtype=latents.dtype,
)
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
with torch.autocast("cuda", dtype=torch.bfloat16):
input_kwargs = {
"hidden_states": noisy_model_input,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timesteps,
"encoder_attention_mask": encoder_attention_mask, # B, L
"return_dict": False,
}
if 'hunyuan' in model_type:
input_kwargs["guidance"] = torch.tensor([1000.0], device=noisy_model_input.device, dtype=torch.bfloat16)
model_pred = transformer(**input_kwargs)[0]
if precondition_outputs:
model_pred = noisy_model_input - model_pred * sigmas
if precondition_outputs:
target = latents
else:
target = noise - latents
loss = (torch.mean((model_pred.float() - target.float())**2) / gradient_accumulation_steps)
loss.backward()
avg_loss = loss.detach().clone()
dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
total_loss += avg_loss.item()
grad_norm = transformer.clip_grad_norm_(max_grad_norm)
optimizer.step()
lr_scheduler.step()
return total_loss, grad_norm.item()
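
# --- Illustrative note (added for this listing, not part of the original file) --
# train_one_step builds the flow-matching pair used above:
#     noisy_model_input = (1 - sigma) * latents + sigma * noise
#     target            = noise - latents   (when precondition_outputs is off)
# i.e. the transformer regresses the constant "velocity" from data to noise.
# `_demo_flow_matching_pair` is a hypothetical helper name used only here.
def _demo_flow_matching_pair(latents, sigmas):
    noise = torch.randn_like(latents)
    noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
    target = noise - latents  # velocity target matching the MSE loss above
    return noisy_model_input, target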
def main(args):
torch.backends.cuda.matmul.allow_tf32 = True
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl")
torch.cuda.set_device(local_rank)
device = torch.cuda.current_device()
initialize_sequence_parallel_state(args.sp_size)
# If passed along, set the training seed now. On GPU...
if args.seed is not None:
# TODO: t within the same seq parallel group should be the same. Noise should be different.
set_seed(args.seed + rank)
# We use different seeds for the noise generation in each process to ensure that the noise is different in a batch.
noise_random_generator = None
# Handle the repository creation
if rank <= 0 and args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
# Create model:
main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
# keep the master weight to float32
print("<"*50)
transformer = load_transformer(
args.model_type,
args.dit_model_name_or_path,
args.pretrained_model_name_or_path,
torch.float32 if args.master_weight_type == "fp32" else torch.bfloat16,
)
print(">"*50)
if args.use_lora:
assert args.model_type != "hunyuan", "LoRA is only supported for huggingface model. Please use hunyuan_hf for lora finetuning"
if args.model_type == "mochi":
pipe = MochiPipeline
elif args.model_type == "hunyuan_hf":
pipe = HunyuanVideoPipeline
transformer.requires_grad_(False)
transformer_lora_config = LoraConfig(
r=args.lora_rank,
lora_alpha=args.lora_alpha,
init_lora_weights=True,
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
transformer.add_adapter(transformer_lora_config)
if args.resume_from_lora_checkpoint:
lora_state_dict = pipe.lora_state_dict(args.resume_from_lora_checkpoint)
transformer_state_dict = {
f'{k.replace("transformer.", "")}': v
for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
main_print(f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. ")
main_print(
f" Total training parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M")
main_print(f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}")
fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
transformer,
args.fsdp_sharding_startegy,
args.use_lora,
args.use_cpu_offload,
args.master_weight_type,
)
if args.use_lora:
transformer.config.lora_rank = args.lora_rank
transformer.config.lora_alpha = args.lora_alpha
transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
transformer._no_split_modules = [no_split_module.__name__ for no_split_module in no_split_modules]
fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer)
transformer = FSDP(
transformer,
**fsdp_kwargs,
)
main_print("--> model loaded")
if args.gradient_checkpointing:
apply_fsdp_checkpointing(transformer, no_split_modules, args.selective_checkpointing)
# Set model as trainable.
transformer.train()
noise_scheduler = FlowMatchEulerDiscreteScheduler()
params_to_optimize = transformer.parameters()
params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
optimizer = torch.optim.AdamW(
params_to_optimize,
lr=args.learning_rate,
betas=(0.9, 0.999),
weight_decay=args.weight_decay,
eps=1e-8,
)
init_steps = 0
if args.resume_from_lora_checkpoint:
transformer, optimizer, init_steps = resume_lora_optimizer(transformer, args.resume_from_lora_checkpoint,
optimizer)
main_print(f"optimizer: {optimizer}")
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
last_epoch=init_steps - 1,
)
train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg)
sampler = (LengthGroupedSampler(
args.train_batch_size,
rank=rank,
world_size=world_size,
lengths=train_dataset.lengths,
group_frame=args.group_frame,
group_resolution=args.group_resolution,
) if (args.group_frame or args.group_resolution) else DistributedSampler(
train_dataset, rank=rank, num_replicas=world_size, shuffle=False))
train_dataloader = DataLoader(
train_dataset,
sampler=sampler,
collate_fn=latent_collate_function,
pin_memory=True,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
drop_last=True,
)
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps * args.sp_size / args.train_sp_batch_size)
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# if rank <= 0:
# project = args.tracker_project_name or "fastvideo"
# wandb.init(project=project, config=args)
# Train!
total_batch_size = (world_size * args.gradient_accumulation_steps / args.sp_size * args.train_sp_batch_size)
main_print("***** Running training *****")
main_print(f" Num examples = {len(train_dataset)}")
main_print(f" Dataloader size = {len(train_dataloader)}")
main_print(f" Num Epochs = {args.num_train_epochs}")
main_print(f" Resume training from step {init_steps}")
main_print(f" Instantaneous batch size per device = {args.train_batch_size}")
main_print(f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}")
main_print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
main_print(f" Total optimization steps = {args.max_train_steps}")
main_print(
f" Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B"
)
# print dtype
main_print(f" Master weight dtype: {transformer.parameters().__next__().dtype}")
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
assert NotImplementedError("resume_from_checkpoint is not supported now.")
# TODO
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=init_steps,
desc="Steps",
# Only show the progress bar once on each machine.
disable=local_rank > 0,
)
loader = sp_parallel_dataloader_wrapper(
train_dataloader,
device,
args.train_batch_size,
args.sp_size,
args.train_sp_batch_size,
)
step_times = deque(maxlen=100)
# todo future
for i in range(init_steps):
next(loader)
for step in range(init_steps + 1, args.max_train_steps + 1):
start_time = time.time()
loss, grad_norm = train_one_step(
transformer,
args.model_type,
optimizer,
lr_scheduler,
loader,
noise_scheduler,
noise_random_generator,
args.gradient_accumulation_steps,
args.sp_size,
args.precondition_outputs,
args.max_grad_norm,
args.weighting_scheme,
args.logit_mean,
args.logit_std,
args.mode_scale,
)
step_time = time.time() - start_time
step_times.append(step_time)
avg_step_time = sum(step_times) / len(step_times)
progress_bar.set_postfix({
"loss": f"{loss:.4f}",
"step_time": f"{step_time:.2f}s",
"grad_norm": grad_norm,
})
progress_bar.update(1)
# if rank <= 0:
# wandb.log(
# {
# "train_loss": loss,
# "learning_rate": lr_scheduler.get_last_lr()[0],
# "step_time": step_time,
# "avg_step_time": avg_step_time,
# "grad_norm": grad_norm,
# },
# step=step,
# )
if step % args.checkpointing_steps == 0:
if args.use_lora:
# Save LoRA weights
save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, step, pipe)
else:
# Your existing checkpoint saving code
save_checkpoint(transformer, rank, args.output_dir, step)
dist.barrier()
if args.log_validation and step % args.validation_steps == 0:
log_validation(args, transformer, device, torch.bfloat16, step, shift=args.shift)
if args.use_lora:
save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipe)
else:
save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps)
if get_sequence_parallel_state():
destroy_sequence_parallel_group()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_type",
type=str,
default="mochi",
help="The type of model to train. Currently supports [mochi, hunyuan_hf, hunyuan]")
# dataset & dataloader
parser.add_argument("--data_json_path", type=str, required=True)
parser.add_argument("--num_height", type=int, default=480)
parser.add_argument("--num_width", type=int, default=848)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=10,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--train_batch_size",
type=int,
default=16,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
parser.add_argument("--group_frame", action="store_true") # TODO
parser.add_argument("--group_resolution", action="store_true") # TODO
# text encoder & vae & diffusion model
parser.add_argument("--pretrained_model_name_or_path", type=str)
parser.add_argument("--dit_model_name_or_path", type=str, default=None)
parser.add_argument("--cache_dir", type=str, default="./cache_dir")
# diffusion setting
parser.add_argument("--ema_decay", type=float, default=0.999)
parser.add_argument("--ema_start_step", type=int, default=0)
parser.add_argument("--cfg", type=float, default=0.1)
parser.add_argument(
"--precondition_outputs",
action="store_true",
help="Whether to precondition the outputs of the model.",
)
# validation & logs
parser.add_argument("--validation_prompt_dir", type=str)
parser.add_argument("--uncond_prompt_dir", type=str)
parser.add_argument(
"--validation_sampling_steps",
type=str,
default="64",
help="use ',' to split multi sampling steps",
)
parser.add_argument(
"--validation_guidance_scale",
type=str,
default="4.5",
help="use ',' to split multi scale",
)
parser.add_argument("--validation_steps", type=int, default=50)
parser.add_argument("--log_validation", action="store_true")
parser.add_argument("--tracker_project_name", type=str, default=None)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=("Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."),
)
parser.add_argument("--shift", type=float, default=1.0, help=("Set shift to 7 for hunyuan model."))
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=("Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
)
parser.add_argument(
"--resume_from_lora_checkpoint",
type=str,
default=None,
help=("Whether training should be resumed from a previous lora checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
)
# optimizer & scheduler & Training
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_warmup_steps",
type=int,
default=10,
help="Number of steps for the warmup in the lr scheduler.",
)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument("--selective_checkpointing", type=float, default=1.0)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
)
parser.add_argument(
"--use_cpu_offload",
action="store_true",
help="Whether to use CPU offload for param & gradient & optimizer states.",
)
parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel")
parser.add_argument(
"--train_sp_batch_size",
type=int,
default=1,
help="Batch size for sequence parallel training",
)
parser.add_argument(
"--use_lora",
action="store_true",
default=False,
help="Whether to use LoRA for finetuning.",
)
parser.add_argument("--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA.")
parser.add_argument("--lora_rank", type=int, default=128, help="LoRA rank parameter. ")
parser.add_argument("--fsdp_sharding_startegy", default="full")
parser.add_argument(
"--weighting_scheme",
type=str,
default="uniform",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
)
parser.add_argument(
"--logit_mean",
type=float,
default=0.0,
help="mean to use when using the `'logit_normal'` weighting scheme.",
)
parser.add_argument(
"--logit_std",
type=float,
default=1.0,
help="std to use when using the `'logit_normal'` weighting scheme.",
)
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
# lr_scheduler
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'),
)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of cycles in the learning rate scheduler.",
)
parser.add_argument(
"--lr_power",
type=float,
default=1.0,
help="Power factor of the polynomial scheduler.",
)
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to apply.")
parser.add_argument(
"--master_weight_type",
type=str,
default="fp32",
help="Weight type to use - fp32 or bf16.",
)
args = parser.parse_args()
main(args)
FastVideo-main/fastvideo/train_back.py
0 → 100644
View file @
c07946d8
# !/bin/python3
# isort: skip_file
import argparse
import math
import os
import time
from collections import deque
import torch
import torch.distributed as dist
import wandb
from accelerate.utils import set_seed
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft
from peft import LoraConfig, set_peft_model_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from fastvideo.dataset.latent_datasets import (LatentDataset, latent_collate_function)
from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline
from fastvideo.utils.checkpoint import (resume_lora_optimizer, save_checkpoint, save_lora_checkpoint)
from fastvideo.utils.communications import (broadcast, sp_parallel_dataloader_wrapper)
from fastvideo.utils.dataset_utils import LengthGroupedSampler
from fastvideo.utils.fsdp_util import (apply_fsdp_checkpointing, get_dit_fsdp_kwargs)
from fastvideo.utils.load import load_transformer
from fastvideo.utils.logging_ import main_print
from fastvideo.utils.parallel_states import (destroy_sequence_parallel_group, get_sequence_parallel_state,
                                             initialize_sequence_parallel_state)
from fastvideo.utils.validation import log_validation
import torch

torch._dynamo.config.capture_scalar_outputs = True
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0")


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    generator,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.
    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(
            mean=logit_mean,
            std=logit_std,
            size=(batch_size, ),
            device="cpu",
            generator=generator,
        )
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2)**2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
    return u


def get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32):
    sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma


def train_one_step(
    transformer,
    model_type,
    optimizer,
    lr_scheduler,
    loader,
    noise_scheduler,
    noise_random_generator,
    gradient_accumulation_steps,
    sp_size,
    precondition_outputs,
    max_grad_norm,
    weighting_scheme,
    logit_mean,
    logit_std,
    mode_scale,
):
    total_loss = 0.0
    optimizer.zero_grad()
    for _ in range(gradient_accumulation_steps):
        (
            latents,
            encoder_hidden_states,
            latents_attention_mask,
            encoder_attention_mask,
        ) = next(loader)
        latents = normalize_dit_input(model_type, latents)
        batch_size = latents.shape[0]
        noise = torch.randn_like(latents)
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme,
            batch_size=batch_size,
            generator=noise_random_generator,
            logit_mean=logit_mean,
            logit_std=logit_std,
            mode_scale=mode_scale,
        )
        indices = (u * noise_scheduler.config.num_train_timesteps).long()
        timesteps = noise_scheduler.timesteps[indices].to(device=latents.device)
        if sp_size > 1:
            # Make sure that the timesteps are the same across all sp processes.
            broadcast(timesteps)
        sigmas = get_sigmas(
            noise_scheduler,
            latents.device,
            timesteps,
            n_dim=latents.ndim,
            dtype=latents.dtype,
        )
        noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
        with torch.autocast("cuda", dtype=torch.bfloat16):
            input_kwargs = {
                "hidden_states": noisy_model_input,
                "encoder_hidden_states": encoder_hidden_states,
                "timestep": timesteps,
                "encoder_attention_mask": encoder_attention_mask,  # B, L
                "return_dict": False,
            }
            if 'hunyuan' in model_type:
                input_kwargs["guidance"] = torch.tensor([1000.0], device=noisy_model_input.device, dtype=torch.bfloat16)
            model_pred = transformer(**input_kwargs)[0]
        if precondition_outputs:
            model_pred = noisy_model_input - model_pred * sigmas
        if precondition_outputs:
            target = latents
        else:
            target = noise - latents
        loss = (torch.mean((model_pred.float() - target.float())**2) / gradient_accumulation_steps)
        loss.backward()
        avg_loss = loss.detach().clone()
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        total_loss += avg_loss.item()
    grad_norm = transformer.clip_grad_norm_(max_grad_norm)
    optimizer.step()
    lr_scheduler.step()
    return total_loss, grad_norm.item()


def main(args):
    torch.backends.cuda.matmul.allow_tf32 = True
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    device = torch.cuda.current_device()
    initialize_sequence_parallel_state(args.sp_size)
    # If passed along, set the training seed now. On GPU...
    if args.seed is not None:
        # TODO: t within the same seq parallel group should be the same. Noise should be different.
        set_seed(args.seed + rank)
    # We use different seeds for the noise generation in each process to ensure that the noise is different in a batch.
    noise_random_generator = None
    # Handle the repository creation
    if rank <= 0 and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
    # For mixed precision training we cast all non-trainable weights to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    # Create model:
    main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
    # keep the master weight to float32
    print("<" * 50)
    transformer = load_transformer(
        args.model_type,
        args.dit_model_name_or_path,
        args.pretrained_model_name_or_path,
        torch.float32 if args.master_weight_type == "fp32" else torch.bfloat16,
    )
    print(">" * 50)
    if args.use_lora:
        assert args.model_type != "hunyuan", "LoRA is only supported for huggingface model. Please use hunyuan_hf for lora finetuning"
        if args.model_type == "mochi":
            pipe = MochiPipeline
        elif args.model_type == "hunyuan_hf":
            pipe = HunyuanVideoPipeline
        transformer.requires_grad_(False)
        transformer_lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            init_lora_weights=True,
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        transformer.add_adapter(transformer_lora_config)
    if args.resume_from_lora_checkpoint:
        lora_state_dict = pipe.lora_state_dict(args.resume_from_lora_checkpoint)
        transformer_state_dict = {
            f'{k.replace("transformer.", "")}': v
            for k, v in lora_state_dict.items() if k.startswith("transformer.")
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        incompatible_keys = set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                main_print(f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                           f" {unexpected_keys}. ")
    main_print(
        f" Total training parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M")
    main_print(f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}")
    fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
        transformer,
        args.fsdp_sharding_startegy,
        args.use_lora,
        args.use_cpu_offload,
        args.master_weight_type,
    )
    if args.use_lora:
        transformer.config.lora_rank = args.lora_rank
        transformer.config.lora_alpha = args.lora_alpha
        transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
        transformer._no_split_modules = [no_split_module.__name__ for no_split_module in no_split_modules]
        fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer)
    fsdp_kwargs['use_orig_params'] = True
    transformer = FSDP(
        transformer,
        **fsdp_kwargs,
    )
    main_print("--> model loaded")
    if args.gradient_checkpointing:
        apply_fsdp_checkpointing(transformer, no_split_modules, args.selective_checkpointing)
    #transformer = torch.compile(transformer)
    # Set model as trainable.
    transformer.train()
    noise_scheduler = FlowMatchEulerDiscreteScheduler()
    params_to_optimize = transformer.parameters()
    params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
        eps=1e-8,
    )
    init_steps = 0
    if args.resume_from_lora_checkpoint:
        transformer, optimizer, init_steps = resume_lora_optimizer(transformer, args.resume_from_lora_checkpoint,
                                                                   optimizer)
    main_print(f"optimizer: {optimizer}")
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
        last_epoch=init_steps - 1,
    )
    train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg)
    sampler = (LengthGroupedSampler(
        args.train_batch_size,
        rank=rank,
        world_size=world_size,
        lengths=train_dataset.lengths,
        group_frame=args.group_frame,
        group_resolution=args.group_resolution,
    ) if (args.group_frame or args.group_resolution) else DistributedSampler(
        train_dataset, rank=rank, num_replicas=world_size, shuffle=False))
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        collate_fn=latent_collate_function,
        pin_memory=True,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        drop_last=True,
    )
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps * args.sp_size / args.train_sp_batch_size)
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    # if rank <= 0:
    #     project = args.tracker_project_name or "fastvideo"
    #     wandb.init(project=project, config=args)
    # Train!
    total_batch_size = (world_size * args.gradient_accumulation_steps / args.sp_size * args.train_sp_batch_size)
    main_print("***** Running training *****")
    main_print(f" Num examples = {len(train_dataset)}")
    main_print(f" Dataloader size = {len(train_dataloader)}")
    main_print(f" Num Epochs = {args.num_train_epochs}")
    main_print(f" Resume training from step {init_steps}")
    main_print(f" Instantaneous batch size per device = {args.train_batch_size}")
    main_print(f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}")
    main_print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    main_print(f" Total optimization steps = {args.max_train_steps}")
    main_print(
        f" Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B"
    )
    # print dtype
    main_print(f" Master weight dtype: {transformer.parameters().__next__().dtype}")
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        assert NotImplementedError("resume_from_checkpoint is not supported now.")
        # TODO
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=init_steps,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=local_rank > 0,
    )
    loader = sp_parallel_dataloader_wrapper(
        train_dataloader,
        device,
        args.train_batch_size,
        args.sp_size,
        args.train_sp_batch_size,
    )
    step_times = deque(maxlen=100)
    # todo future
    for i in range(init_steps):
        next(loader)
    for step in range(init_steps + 1, args.max_train_steps + 1):
        start_time = time.time()
        loss, grad_norm = train_one_step(
            transformer,
            args.model_type,
            optimizer,
            lr_scheduler,
            loader,
            noise_scheduler,
            noise_random_generator,
            args.gradient_accumulation_steps,
            args.sp_size,
            args.precondition_outputs,
            args.max_grad_norm,
            args.weighting_scheme,
            args.logit_mean,
            args.logit_std,
            args.mode_scale,
        )
        step_time = time.time() - start_time
        step_times.append(step_time)
        avg_step_time = sum(step_times) / len(step_times)
        progress_bar.set_postfix({
            "loss": f"{loss:.4f}",
            "step_time": f"{step_time:.2f}s",
            "grad_norm": grad_norm,
        })
        progress_bar.update(1)
        # if rank <= 0:
        #     wandb.log(
        #         {
        #             "train_loss": loss,
        #             "learning_rate": lr_scheduler.get_last_lr()[0],
        #             "step_time": step_time,
        #             "avg_step_time": avg_step_time,
        #             "grad_norm": grad_norm,
        #         },
        #         step=step,
        #     )
        main_print(f"zll step_time: {step_time:.2f}s avg_step_time: {sum(step_times) / len(step_times)}")
        if step % args.checkpointing_steps == 0:
            if args.use_lora:
                # Save LoRA weights
                save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, step, pipe)
            else:
                # Your existing checkpoint saving code
                save_checkpoint(transformer, rank, args.output_dir, step)
            dist.barrier()
        if args.log_validation and step % args.validation_steps == 0:
            log_validation(args, transformer, device, torch.bfloat16, step, shift=args.shift)
    if args.use_lora:
        save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipe)
    else:
        save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps)
    if get_sequence_parallel_state():
        destroy_sequence_parallel_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type",
                        type=str,
                        default="mochi",
                        help="The type of model to train. Currentlt support [mochi, hunyuan_hf, hunyuan]")
    # dataset & dataloader
    parser.add_argument("--data_json_path", type=str, required=True)
    parser.add_argument("--num_height", type=int, default=480)
    parser.add_argument("--num_width", type=int, default=848)
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=10,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
    parser.add_argument("--group_frame", action="store_true")  # TODO
    parser.add_argument("--group_resolution", action="store_true")  # TODO
    # text encoder & vae & diffusion model
    parser.add_argument("--pretrained_model_name_or_path", type=str)
    parser.add_argument("--dit_model_name_or_path", type=str, default=None)
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    # diffusion setting
    parser.add_argument("--ema_decay", type=float, default=0.999)
    parser.add_argument("--ema_start_step", type=int, default=0)
    parser.add_argument("--cfg", type=float, default=0.1)
    parser.add_argument(
        "--precondition_outputs",
        action="store_true",
        help="Whether to precondition the outputs of the model.",
    )
    # validation & logs
    parser.add_argument("--validation_prompt_dir", type=str)
parser
.
add_argument
(
"--uncond_prompt_dir"
,
type
=
str
)
parser
.
add_argument
(
"--validation_sampling_steps"
,
type
=
str
,
default
=
"64"
,
help
=
"use ',' to split multi sampling steps"
,
)
parser
.
add_argument
(
"--validation_guidance_scale"
,
type
=
str
,
default
=
"4.5"
,
help
=
"use ',' to split multi scale"
,
)
parser
.
add_argument
(
"--validation_steps"
,
type
=
int
,
default
=
50
)
parser
.
add_argument
(
"--log_validation"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--tracker_project_name"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
None
,
help
=
"A seed for reproducible training."
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
None
,
help
=
"The output directory where the model predictions and checkpoints will be written."
,
)
parser
.
add_argument
(
"--checkpoints_total_limit"
,
type
=
int
,
default
=
None
,
help
=
(
"Max number of checkpoints to store."
),
)
parser
.
add_argument
(
"--checkpointing_steps"
,
type
=
int
,
default
=
500
,
help
=
(
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser
.
add_argument
(
"--shift"
,
type
=
float
,
default
=
1.0
,
help
=
(
"Set shift to 7 for hunyuan model."
))
parser
.
add_argument
(
"--resume_from_checkpoint"
,
type
=
str
,
default
=
None
,
help
=
(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser
.
add_argument
(
"--resume_from_lora_checkpoint"
,
type
=
str
,
default
=
None
,
help
=
(
"Whether training should be resumed from a previous lora checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser
.
add_argument
(
"--logging_dir"
,
type
=
str
,
default
=
"logs"
,
help
=
(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
# optimizer & scheduler & Training
parser
.
add_argument
(
"--num_train_epochs"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--max_train_steps"
,
type
=
int
,
default
=
None
,
help
=
"Total number of training steps to perform. If provided, overrides num_train_epochs."
,
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
,
help
=
"Number of updates steps to accumulate before performing a backward/update pass."
,
)
parser
.
add_argument
(
"--learning_rate"
,
type
=
float
,
default
=
1e-4
,
help
=
"Initial learning rate (after the potential warmup period) to use."
,
)
parser
.
add_argument
(
"--scale_lr"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size."
,
)
parser
.
add_argument
(
"--lr_warmup_steps"
,
type
=
int
,
default
=
10
,
help
=
"Number of steps for the warmup in the lr scheduler."
,
)
parser
.
add_argument
(
"--max_grad_norm"
,
default
=
1.0
,
type
=
float
,
help
=
"Max gradient norm."
)
parser
.
add_argument
(
"--gradient_checkpointing"
,
action
=
"store_true"
,
help
=
"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass."
,
)
parser
.
add_argument
(
"--selective_checkpointing"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--allow_tf32"
,
action
=
"store_true"
,
help
=
(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser
.
add_argument
(
"--mixed_precision"
,
type
=
str
,
default
=
None
,
choices
=
[
"no"
,
"fp16"
,
"bf16"
],
help
=
(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser
.
add_argument
(
"--use_cpu_offload"
,
action
=
"store_true"
,
help
=
"Whether to use CPU offload for param & gradient & optimizer states."
,
)
parser
.
add_argument
(
"--sp_size"
,
type
=
int
,
default
=
1
,
help
=
"For sequence parallel"
)
parser
.
add_argument
(
"--train_sp_batch_size"
,
type
=
int
,
default
=
1
,
help
=
"Batch size for sequence parallel training"
,
)
parser
.
add_argument
(
"--use_lora"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use LoRA for finetuning."
,
)
parser
.
add_argument
(
"--lora_alpha"
,
type
=
int
,
default
=
256
,
help
=
"Alpha parameter for LoRA."
)
parser
.
add_argument
(
"--lora_rank"
,
type
=
int
,
default
=
128
,
help
=
"LoRA rank parameter. "
)
parser
.
add_argument
(
"--fsdp_sharding_startegy"
,
default
=
"full"
)
parser
.
add_argument
(
"--weighting_scheme"
,
type
=
str
,
default
=
"uniform"
,
choices
=
[
"sigma_sqrt"
,
"logit_normal"
,
"mode"
,
"cosmap"
,
"uniform"
],
)
parser
.
add_argument
(
"--logit_mean"
,
type
=
float
,
default
=
0.0
,
help
=
"mean to use when using the `'logit_normal'` weighting scheme."
,
)
parser
.
add_argument
(
"--logit_std"
,
type
=
float
,
default
=
1.0
,
help
=
"std to use when using the `'logit_normal'` weighting scheme."
,
)
parser
.
add_argument
(
"--mode_scale"
,
type
=
float
,
default
=
1.29
,
help
=
"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`."
,
)
# lr_scheduler
parser
.
add_argument
(
"--lr_scheduler"
,
type
=
str
,
default
=
"constant"
,
help
=
(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser
.
add_argument
(
"--lr_num_cycles"
,
type
=
int
,
default
=
1
,
help
=
"Number of cycles in the learning rate scheduler."
,
)
parser
.
add_argument
(
"--lr_power"
,
type
=
float
,
default
=
1.0
,
help
=
"Power factor of the polynomial scheduler."
,
)
parser
.
add_argument
(
"--weight_decay"
,
type
=
float
,
default
=
0.01
,
help
=
"Weight decay to apply."
)
parser
.
add_argument
(
"--master_weight_type"
,
type
=
str
,
default
=
"fp32"
,
help
=
"Weight type to use - fp32 or bf16."
,
)
args
=
parser
.
parse_args
()
main
(
args
)
FastVideo-main/fastvideo/train_prof.py
0 → 100644
View file @
c07946d8
# !/bin/python3
# isort: skip_file
import
argparse
import
math
import
os
import
time
from
collections
import
deque
import
torch
import
torch.distributed
as
dist
import
wandb
from
accelerate.utils
import
set_seed
from
diffusers
import
FlowMatchEulerDiscreteScheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.utils
import
check_min_version
,
convert_unet_state_dict_to_peft
from
peft
import
LoraConfig
,
set_peft_model_state_dict
from
torch.distributed.fsdp
import
FullyShardedDataParallel
as
FSDP
from
torch.utils.data
import
DataLoader
from
torch.utils.data.distributed
import
DistributedSampler
from
tqdm.auto
import
tqdm
from
fastvideo.dataset.latent_datasets
import
(
LatentDataset
,
latent_collate_function
)
from
fastvideo.models.mochi_hf.mochi_latents_utils
import
normalize_dit_input
from
fastvideo.models.mochi_hf.pipeline_mochi
import
MochiPipeline
from
fastvideo.models.hunyuan_hf.pipeline_hunyuan
import
HunyuanVideoPipeline
from
fastvideo.utils.checkpoint
import
(
resume_lora_optimizer
,
save_checkpoint
,
save_lora_checkpoint
)
from
fastvideo.utils.communications
import
(
broadcast
,
sp_parallel_dataloader_wrapper
)
from
fastvideo.utils.dataset_utils
import
LengthGroupedSampler
from
fastvideo.utils.fsdp_util
import
(
apply_fsdp_checkpointing
,
get_dit_fsdp_kwargs
)
from
fastvideo.utils.load
import
load_transformer
from
fastvideo.utils.logging_
import
main_print
from
fastvideo.utils.parallel_states
import
(
destroy_sequence_parallel_group
,
get_sequence_parallel_state
,
initialize_sequence_parallel_state
)
from
fastvideo.utils.validation
import
log_validation
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version
(
"0.31.0"
)
from
torch.profiler
import
profile
,
record_function
,
ProfilerActivity
def
trace_handler
(
p
):
output
=
p
.
key_averages
().
table
(
sort_by
=
"self_cuda_time_total"
,
row_limit
=
50
)
print
(
output
)
rank
=
dist
.
get_rank
()
p
.
export_chrome_trace
(
"/public/hy-code/FastVideo-main/scripts/finetune/prof/BW_amd"
+
str
(
rank
)
+
"_"
+
str
(
p
.
step_num
)
+
".json"
)
def
compute_density_for_timestep_sampling
(
weighting_scheme
:
str
,
batch_size
:
int
,
generator
,
logit_mean
:
float
=
None
,
logit_std
:
float
=
None
,
mode_scale
:
float
=
None
,
):
"""
Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if
weighting_scheme
==
"logit_normal"
:
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u
=
torch
.
normal
(
mean
=
logit_mean
,
std
=
logit_std
,
size
=
(
batch_size
,
),
device
=
"cpu"
,
generator
=
generator
,
)
u
=
torch
.
nn
.
functional
.
sigmoid
(
u
)
elif
weighting_scheme
==
"mode"
:
u
=
torch
.
rand
(
size
=
(
batch_size
,
),
device
=
"cpu"
,
generator
=
generator
)
u
=
1
-
u
-
mode_scale
*
(
torch
.
cos
(
math
.
pi
*
u
/
2
)
**
2
-
1
+
u
)
else
:
u
=
torch
.
rand
(
size
=
(
batch_size
,
),
device
=
"cpu"
,
generator
=
generator
)
return
u
def
get_sigmas
(
noise_scheduler
,
device
,
timesteps
,
n_dim
=
4
,
dtype
=
torch
.
float32
):
sigmas
=
noise_scheduler
.
sigmas
.
to
(
device
=
device
,
dtype
=
dtype
)
schedule_timesteps
=
noise_scheduler
.
timesteps
.
to
(
device
)
timesteps
=
timesteps
.
to
(
device
)
step_indices
=
[(
schedule_timesteps
==
t
).
nonzero
().
item
()
for
t
in
timesteps
]
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
n_dim
:
sigma
=
sigma
.
unsqueeze
(
-
1
)
return
sigma
def
train_one_step
(
transformer
,
model_type
,
optimizer
,
lr_scheduler
,
loader
,
noise_scheduler
,
noise_random_generator
,
gradient_accumulation_steps
,
sp_size
,
precondition_outputs
,
max_grad_norm
,
weighting_scheme
,
logit_mean
,
logit_std
,
mode_scale
,
):
total_loss
=
0.0
optimizer
.
zero_grad
()
for
_
in
range
(
gradient_accumulation_steps
):
(
latents
,
encoder_hidden_states
,
latents_attention_mask
,
encoder_attention_mask
,
)
=
next
(
loader
)
latents
=
normalize_dit_input
(
model_type
,
latents
)
batch_size
=
latents
.
shape
[
0
]
noise
=
torch
.
randn_like
(
latents
)
u
=
compute_density_for_timestep_sampling
(
weighting_scheme
=
weighting_scheme
,
batch_size
=
batch_size
,
generator
=
noise_random_generator
,
logit_mean
=
logit_mean
,
logit_std
=
logit_std
,
mode_scale
=
mode_scale
,
)
indices
=
(
u
*
noise_scheduler
.
config
.
num_train_timesteps
).
long
()
timesteps
=
noise_scheduler
.
timesteps
[
indices
].
to
(
device
=
latents
.
device
)
if
sp_size
>
1
:
# Make sure that the timesteps are the same across all sp processes.
broadcast
(
timesteps
)
sigmas
=
get_sigmas
(
noise_scheduler
,
latents
.
device
,
timesteps
,
n_dim
=
latents
.
ndim
,
dtype
=
latents
.
dtype
,
)
noisy_model_input
=
(
1.0
-
sigmas
)
*
latents
+
sigmas
*
noise
with
torch
.
autocast
(
"cuda"
,
dtype
=
torch
.
bfloat16
):
input_kwargs
=
{
"hidden_states"
:
noisy_model_input
,
"encoder_hidden_states"
:
encoder_hidden_states
,
"timestep"
:
timesteps
,
"encoder_attention_mask"
:
encoder_attention_mask
,
# B, L
"return_dict"
:
False
,
}
if
'hunyuan'
in
model_type
:
input_kwargs
[
"guidance"
]
=
torch
.
tensor
([
1000.0
],
device
=
noisy_model_input
.
device
,
dtype
=
torch
.
bfloat16
)
model_pred
=
transformer
(
**
input_kwargs
)[
0
]
if
precondition_outputs
:
model_pred
=
noisy_model_input
-
model_pred
*
sigmas
if
precondition_outputs
:
target
=
latents
else
:
target
=
noise
-
latents
loss
=
(
torch
.
mean
((
model_pred
.
float
()
-
target
.
float
())
**
2
)
/
gradient_accumulation_steps
)
loss
.
backward
()
avg_loss
=
loss
.
detach
().
clone
()
dist
.
all_reduce
(
avg_loss
,
op
=
dist
.
ReduceOp
.
AVG
)
total_loss
+=
avg_loss
.
item
()
grad_norm
=
transformer
.
clip_grad_norm_
(
max_grad_norm
)
optimizer
.
step
()
lr_scheduler
.
step
()
return
total_loss
,
grad_norm
.
item
()
def
main
(
args
):
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
True
local_rank
=
int
(
os
.
environ
[
"LOCAL_RANK"
])
rank
=
int
(
os
.
environ
[
"RANK"
])
world_size
=
int
(
os
.
environ
[
"WORLD_SIZE"
])
dist
.
init_process_group
(
"nccl"
)
torch
.
cuda
.
set_device
(
local_rank
)
device
=
torch
.
cuda
.
current_device
()
initialize_sequence_parallel_state
(
args
.
sp_size
)
# If passed along, set the training seed now. On GPU...
if
args
.
seed
is
not
None
:
# TODO: t within the same seq parallel group should be the same. Noise should be different.
set_seed
(
args
.
seed
+
rank
)
# We use different seeds for the noise generation in each process to ensure that the noise is different in a batch.
noise_random_generator
=
None
# Handle the repository creation
if
rank
<=
0
and
args
.
output_dir
is
not
None
:
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
# For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
# Create model:
main_print
(
f
"--> loading model from
{
args
.
pretrained_model_name_or_path
}
"
)
# keep the master weight to float32
print
(
"<"
*
50
)
transformer
=
load_transformer
(
args
.
model_type
,
args
.
dit_model_name_or_path
,
args
.
pretrained_model_name_or_path
,
torch
.
float32
if
args
.
master_weight_type
==
"fp32"
else
torch
.
bfloat16
,
)
print
(
">"
*
50
)
if
args
.
use_lora
:
assert
args
.
model_type
!=
"hunyuan"
,
"LoRA is only supported for huggingface model. Please use hunyuan_hf for lora finetuning"
if
args
.
model_type
==
"mochi"
:
pipe
=
MochiPipeline
elif
args
.
model_type
==
"hunyuan_hf"
:
pipe
=
HunyuanVideoPipeline
transformer
.
requires_grad_
(
False
)
transformer_lora_config
=
LoraConfig
(
r
=
args
.
lora_rank
,
lora_alpha
=
args
.
lora_alpha
,
init_lora_weights
=
True
,
target_modules
=
[
"to_k"
,
"to_q"
,
"to_v"
,
"to_out.0"
],
)
transformer
.
add_adapter
(
transformer_lora_config
)
if
args
.
resume_from_lora_checkpoint
:
lora_state_dict
=
pipe
.
lora_state_dict
(
args
.
resume_from_lora_checkpoint
)
transformer_state_dict
=
{
f
'
{
k
.
replace
(
"transformer."
,
""
)
}
'
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"transformer."
)
}
transformer_state_dict
=
convert_unet_state_dict_to_peft
(
transformer_state_dict
)
incompatible_keys
=
set_peft_model_state_dict
(
transformer
,
transformer_state_dict
,
adapter_name
=
"default"
)
if
incompatible_keys
is
not
None
:
# check only for unexpected keys
unexpected_keys
=
getattr
(
incompatible_keys
,
"unexpected_keys"
,
None
)
if
unexpected_keys
:
main_print
(
f
"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f
"
{
unexpected_keys
}
. "
)
main_print
(
f
" Total training parameters =
{
sum
(
p
.
numel
()
for
p
in
transformer
.
parameters
()
if
p
.
requires_grad
)
/
1e6
}
M"
)
main_print
(
f
"--> Initializing FSDP with sharding strategy:
{
args
.
fsdp_sharding_startegy
}
"
)
fsdp_kwargs
,
no_split_modules
=
get_dit_fsdp_kwargs
(
transformer
,
args
.
fsdp_sharding_startegy
,
args
.
use_lora
,
args
.
use_cpu_offload
,
args
.
master_weight_type
,
)
if
args
.
use_lora
:
transformer
.
config
.
lora_rank
=
args
.
lora_rank
transformer
.
config
.
lora_alpha
=
args
.
lora_alpha
transformer
.
config
.
lora_target_modules
=
[
"to_k"
,
"to_q"
,
"to_v"
,
"to_out.0"
]
transformer
.
_no_split_modules
=
[
no_split_module
.
__name__
for
no_split_module
in
no_split_modules
]
fsdp_kwargs
[
"auto_wrap_policy"
]
=
fsdp_kwargs
[
"auto_wrap_policy"
](
transformer
)
transformer
=
FSDP
(
transformer
,
**
fsdp_kwargs
,
)
main_print
(
"--> model loaded"
)
if
args
.
gradient_checkpointing
:
apply_fsdp_checkpointing
(
transformer
,
no_split_modules
,
args
.
selective_checkpointing
)
# Set model as trainable.
transformer
.
train
()
noise_scheduler
=
FlowMatchEulerDiscreteScheduler
()
params_to_optimize
=
transformer
.
parameters
()
params_to_optimize
=
list
(
filter
(
lambda
p
:
p
.
requires_grad
,
params_to_optimize
))
optimizer
=
torch
.
optim
.
AdamW
(
params_to_optimize
,
lr
=
args
.
learning_rate
,
betas
=
(
0.9
,
0.999
),
weight_decay
=
args
.
weight_decay
,
eps
=
1e-8
,
)
init_steps
=
0
if
args
.
resume_from_lora_checkpoint
:
transformer
,
optimizer
,
init_steps
=
resume_lora_optimizer
(
transformer
,
args
.
resume_from_lora_checkpoint
,
optimizer
)
main_print
(
f
"optimizer:
{
optimizer
}
"
)
lr_scheduler
=
get_scheduler
(
args
.
lr_scheduler
,
optimizer
=
optimizer
,
num_warmup_steps
=
args
.
lr_warmup_steps
,
num_training_steps
=
args
.
max_train_steps
,
num_cycles
=
args
.
lr_num_cycles
,
power
=
args
.
lr_power
,
last_epoch
=
init_steps
-
1
,
)
train_dataset
=
LatentDataset
(
args
.
data_json_path
,
args
.
num_latent_t
,
args
.
cfg
)
sampler
=
(
LengthGroupedSampler
(
args
.
train_batch_size
,
rank
=
rank
,
world_size
=
world_size
,
lengths
=
train_dataset
.
lengths
,
group_frame
=
args
.
group_frame
,
group_resolution
=
args
.
group_resolution
,
)
if
(
args
.
group_frame
or
args
.
group_resolution
)
else
DistributedSampler
(
train_dataset
,
rank
=
rank
,
num_replicas
=
world_size
,
shuffle
=
False
))
train_dataloader
=
DataLoader
(
train_dataset
,
sampler
=
sampler
,
collate_fn
=
latent_collate_function
,
pin_memory
=
True
,
batch_size
=
args
.
train_batch_size
,
num_workers
=
args
.
dataloader_num_workers
,
drop_last
=
True
,
)
num_update_steps_per_epoch
=
math
.
ceil
(
len
(
train_dataloader
)
/
args
.
gradient_accumulation_steps
*
args
.
sp_size
/
args
.
train_sp_batch_size
)
args
.
num_train_epochs
=
math
.
ceil
(
args
.
max_train_steps
/
num_update_steps_per_epoch
)
# if rank <= 0:
# project = args.tracker_project_name or "fastvideo"
# wandb.init(project=project, config=args)
# Train!
total_batch_size
=
(
world_size
*
args
.
gradient_accumulation_steps
/
args
.
sp_size
*
args
.
train_sp_batch_size
)
main_print
(
"***** Running training *****"
)
main_print
(
f
" Num examples =
{
len
(
train_dataset
)
}
"
)
main_print
(
f
" Dataloader size =
{
len
(
train_dataloader
)
}
"
)
main_print
(
f
" Num Epochs =
{
args
.
num_train_epochs
}
"
)
main_print
(
f
" Resume training from step
{
init_steps
}
"
)
main_print
(
f
" Instantaneous batch size per device =
{
args
.
train_batch_size
}
"
)
main_print
(
f
" Total train batch size (w. data & sequence parallel, accumulation) =
{
total_batch_size
}
"
)
main_print
(
f
" Gradient Accumulation steps =
{
args
.
gradient_accumulation_steps
}
"
)
main_print
(
f
" Total optimization steps =
{
args
.
max_train_steps
}
"
)
main_print
(
f
" Total training parameters per FSDP shard =
{
sum
(
p
.
numel
()
for
p
in
transformer
.
parameters
()
if
p
.
requires_grad
)
/
1e9
}
B"
)
# print dtype
main_print
(
f
" Master weight dtype:
{
transformer
.
parameters
().
__next__
().
dtype
}
"
)
# Potentially load in the weights and states from a previous save
if
args
.
resume_from_checkpoint
:
assert
NotImplementedError
(
"resume_from_checkpoint is not supported now."
)
# TODO
progress_bar
=
tqdm
(
range
(
0
,
args
.
max_train_steps
),
initial
=
init_steps
,
desc
=
"Steps"
,
# Only show the progress bar once on each machine.
disable
=
local_rank
>
0
,
)
loader
=
sp_parallel_dataloader_wrapper
(
train_dataloader
,
device
,
args
.
train_batch_size
,
args
.
sp_size
,
args
.
train_sp_batch_size
,
)
step_times
=
deque
(
maxlen
=
100
)
# todo future
for
i
in
range
(
init_steps
):
next
(
loader
)
with
profile
(
activities
=
[
ProfilerActivity
.
CPU
,
ProfilerActivity
.
CUDA
],
schedule
=
torch
.
profiler
.
schedule
(
wait
=
10
,
warmup
=
5
,
active
=
1
),
on_trace_ready
=
trace_handler
)
as
p
:
for
step
in
range
(
init_steps
+
1
,
args
.
max_train_steps
+
1
):
start_time
=
time
.
time
()
loss
,
grad_norm
=
train_one_step
(
transformer
,
args
.
model_type
,
optimizer
,
lr_scheduler
,
loader
,
noise_scheduler
,
noise_random_generator
,
args
.
gradient_accumulation_steps
,
args
.
sp_size
,
args
.
precondition_outputs
,
args
.
max_grad_norm
,
args
.
weighting_scheme
,
args
.
logit_mean
,
args
.
logit_std
,
args
.
mode_scale
,
)
p
.
step
()
step_time
=
time
.
time
()
-
start_time
step_times
.
append
(
step_time
)
avg_step_time
=
sum
(
step_times
)
/
len
(
step_times
)
progress_bar
.
set_postfix
({
"loss"
:
f
"
{
loss
:.
4
f
}
"
,
"step_time"
:
f
"
{
step_time
:.
2
f
}
s"
,
"grad_norm"
:
grad_norm
,
})
progress_bar
.
update
(
1
)
# if rank <= 0:
# wandb.log(
# {
# "train_loss": loss,
# "learning_rate": lr_scheduler.get_last_lr()[0],
# "step_time": step_time,
# "avg_step_time": avg_step_time,
# "grad_norm": grad_norm,
# },
# step=step,
# )
if
step
%
args
.
checkpointing_steps
==
0
:
if
args
.
use_lora
:
# Save LoRA weights
save_lora_checkpoint
(
transformer
,
optimizer
,
rank
,
args
.
output_dir
,
step
,
pipe
)
else
:
# Your existing checkpoint saving code
save_checkpoint
(
transformer
,
rank
,
args
.
output_dir
,
step
)
dist
.
barrier
()
if
args
.
log_validation
and
step
%
args
.
validation_steps
==
0
:
log_validation
(
args
,
transformer
,
device
,
torch
.
bfloat16
,
step
,
shift
=
args
.
shift
)
if
args
.
use_lora
:
save_lora_checkpoint
(
transformer
,
optimizer
,
rank
,
args
.
output_dir
,
args
.
max_train_steps
,
pipe
)
else
:
save_checkpoint
(
transformer
,
rank
,
args
.
output_dir
,
args
.
max_train_steps
)
if
get_sequence_parallel_state
():
destroy_sequence_parallel_group
()
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model_type"
,
type
=
str
,
default
=
"mochi"
,
help
=
"The type of model to train. Currentlt support [mochi, hunyuan_hf, hunyuan]"
)
# dataset & dataloader
parser
.
add_argument
(
"--data_json_path"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--num_height"
,
type
=
int
,
default
=
480
)
parser
.
add_argument
(
"--num_width"
,
type
=
int
,
default
=
848
)
parser
.
add_argument
(
"--num_frames"
,
type
=
int
,
default
=
163
)
parser
.
add_argument
(
"--dataloader_num_workers"
,
type
=
int
,
default
=
10
,
help
=
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
,
)
parser
.
add_argument
(
"--train_batch_size"
,
type
=
int
,
default
=
16
,
help
=
"Batch size (per device) for the training dataloader."
,
)
parser
.
add_argument
(
"--num_latent_t"
,
type
=
int
,
default
=
28
,
help
=
"Number of latent timesteps."
)
parser
.
add_argument
(
"--group_frame"
,
action
=
"store_true"
)
# TODO
parser
.
add_argument
(
"--group_resolution"
,
action
=
"store_true"
)
# TODO
# text encoder & vae & diffusion model
parser
.
add_argument
(
"--pretrained_model_name_or_path"
,
type
=
str
)
parser
.
add_argument
(
"--dit_model_name_or_path"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--cache_dir"
,
type
=
str
,
default
=
"./cache_dir"
)
# diffusion setting
parser
.
add_argument
(
"--ema_decay"
,
type
=
float
,
default
=
0.999
)
parser
.
add_argument
(
"--ema_start_step"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--cfg"
,
type
=
float
,
default
=
0.1
)
parser
.
add_argument
(
"--precondition_outputs"
,
action
=
"store_true"
,
help
=
"Whether to precondition the outputs of the model."
,
)
# validation & logs
parser
.
add_argument
(
"--validation_prompt_dir"
,
type
=
str
)
parser
.
add_argument
(
"--uncond_prompt_dir"
,
type
=
str
)
parser
.
add_argument
(
"--validation_sampling_steps"
,
type
=
str
,
default
=
"64"
,
help
=
"use ',' to split multi sampling steps"
,
)
parser
.
add_argument
(
"--validation_guidance_scale"
,
type
=
str
,
default
=
"4.5"
,
help
=
"use ',' to split multi scale"
,
)
parser
.
add_argument
(
"--validation_steps"
,
type
=
int
,
default
=
50
)
parser
.
add_argument
(
"--log_validation"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--tracker_project_name"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
None
,
help
=
"A seed for reproducible training."
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
None
,
help
=
"The output directory where the model predictions and checkpoints will be written."
,
)
parser
.
add_argument
(
"--checkpoints_total_limit"
,
type
=
int
,
default
=
None
,
help
=
(
"Max number of checkpoints to store."
),
)
parser
.
add_argument
(
"--checkpointing_steps"
,
type
=
int
,
default
=
500
,
help
=
(
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser
.
add_argument
(
"--shift"
,
type
=
float
,
default
=
1.0
,
help
=
(
"Set shift to 7 for hunyuan model."
))
parser
.
add_argument
(
"--resume_from_checkpoint"
,
type
=
str
,
default
=
None
,
help
=
(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser
.
add_argument
(
"--resume_from_lora_checkpoint"
,
type
=
str
,
default
=
None
,
help
=
(
"Whether training should be resumed from a previous lora checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser
.
add_argument
(
"--logging_dir"
,
type
=
str
,
default
=
"logs"
,
help
=
(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
# optimizer & scheduler & Training
parser
.
add_argument
(
"--num_train_epochs"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--max_train_steps"
,
type
=
int
,
default
=
None
,
help
=
"Total number of training steps to perform. If provided, overrides num_train_epochs."
,
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
,
help
=
"Number of updates steps to accumulate before performing a backward/update pass."
,
)
parser
.
add_argument
(
"--learning_rate"
,
type
=
float
,
default
=
1e-4
,
help
=
"Initial learning rate (after the potential warmup period) to use."
,
)
parser
.
add_argument
(
"--scale_lr"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size."
,
)
parser
.
add_argument
(
"--lr_warmup_steps"
,
type
=
int
,
default
=
10
,
help
=
"Number of steps for the warmup in the lr scheduler."
,
)
parser
.
add_argument
(
"--max_grad_norm"
,
default
=
1.0
,
type
=
float
,
help
=
"Max gradient norm."
)
parser
.
add_argument
(
"--gradient_checkpointing"
,
action
=
"store_true"
,
help
=
"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass."
,
)
parser
.
add_argument
(
"--selective_checkpointing"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--allow_tf32"
,
action
=
"store_true"
,
help
=
(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser
.
add_argument
(
"--mixed_precision"
,
type
=
str
,
default
=
None
,
choices
=
[
"no"
,
"fp16"
,
"bf16"
],
help
=
(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser
.
add_argument
(
"--use_cpu_offload"
,
action
=
"store_true"
,
help
=
"Whether to use CPU offload for param & gradient & optimizer states."
,
)
parser
.
add_argument
(
"--sp_size"
,
type
=
int
,
default
=
1
,
help
=
"For sequence parallel"
)
parser
.
add_argument
(
"--train_sp_batch_size"
,
type
=
int
,
default
=
1
,
help
=
"Batch size for sequence parallel training"
,
)
parser
.
add_argument
(
"--use_lora"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use LoRA for finetuning."
,
)
parser
.
add_argument
(
"--lora_alpha"
,
type
=
int
,
default
=
256
,
help
=
"Alpha parameter for LoRA."
)
parser
.
add_argument
(
"--lora_rank"
,
type
=
int
,
default
=
128
,
help
=
"LoRA rank parameter. "
)
parser
.
add_argument
(
"--fsdp_sharding_startegy"
,
default
=
"full"
)
parser
.
add_argument
(
"--weighting_scheme"
,
type
=
str
,
default
=
"uniform"
,
choices
=
[
"sigma_sqrt"
,
"logit_normal"
,
"mode"
,
"cosmap"
,
"uniform"
],
)
parser
.
add_argument
(
"--logit_mean"
,
type
=
float
,
default
=
0.0
,
help
=
"mean to use when using the `'logit_normal'` weighting scheme."
,
)
parser
.
add_argument
(
"--logit_std"
,
type
=
float
,
default
=
1.0
,
help
=
"std to use when using the `'logit_normal'` weighting scheme."
,
)
parser
.
add_argument
(
"--mode_scale"
,
type
=
float
,
default
=
1.29
,
help
=
"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`."
,
)
# lr_scheduler
parser
.
add_argument
(
"--lr_scheduler"
,
type
=
str
,
default
=
"constant"
,
help
=
(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser
.
add_argument
(
"--lr_num_cycles"
,
type
=
int
,
default
=
1
,
help
=
"Number of cycles in the learning rate scheduler."
,
)
parser
.
add_argument
(
"--lr_power"
,
type
=
float
,
default
=
1.0
,
help
=
"Power factor of the polynomial scheduler."
,
)
parser
.
add_argument
(
"--weight_decay"
,
type
=
float
,
default
=
0.01
,
help
=
"Weight decay to apply."
)
parser
.
add_argument
(
"--master_weight_type"
,
type
=
str
,
default
=
"fp32"
,
help
=
"Weight type to use - fp32 or bf16."
,
)
args
=
parser
.
parse_args
()
main
(
args
)
FastVideo-main/fastvideo/utils/checkpoint.py
0 → 100644
View file @
c07946d8
# import
import
json
import
os
import
torch
import
torch.distributed.checkpoint
as
dist_cp
from
peft
import
get_peft_model_state_dict
from
safetensors.torch
import
load_file
,
save_file
from
torch.distributed.checkpoint.default_planner
import
DefaultLoadPlanner
,
DefaultSavePlanner
from
torch.distributed.checkpoint.optimizer
import
load_sharded_optimizer_state_dict
from
torch.distributed.fsdp
import
FullOptimStateDictConfig
,
FullStateDictConfig
from
torch.distributed.fsdp
import
FullyShardedDataParallel
as
FSDP
from
torch.distributed.fsdp
import
StateDictType
from
fastvideo.utils.logging_
import
main_print
def
save_checkpoint_optimizer
(
model
,
optimizer
,
rank
,
output_dir
,
step
,
discriminator
=
False
):
with
FSDP
.
state_dict_type
(
model
,
StateDictType
.
FULL_STATE_DICT
,
FullStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
),
FullOptimStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
),
):
cpu_state
=
model
.
state_dict
()
optim_state
=
FSDP
.
optim_state_dict
(
model
,
optimizer
,
)
# todo move to get_state_dict
save_dir
=
os
.
path
.
join
(
output_dir
,
f
"checkpoint-
{
step
}
"
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
# save using safetensors
if
rank
<=
0
and
not
discriminator
:
weight_path
=
os
.
path
.
join
(
save_dir
,
"diffusion_pytorch_model.safetensors"
)
save_file
(
cpu_state
,
weight_path
)
config_dict
=
dict
(
model
.
config
)
config_dict
.
pop
(
'dtype'
)
config_path
=
os
.
path
.
join
(
save_dir
,
"config.json"
)
# save dict as json
with
open
(
config_path
,
"w"
)
as
f
:
json
.
dump
(
config_dict
,
f
,
indent
=
4
)
optimizer_path
=
os
.
path
.
join
(
save_dir
,
"optimizer.pt"
)
torch
.
save
(
optim_state
,
optimizer_path
)
else
:
weight_path
=
os
.
path
.
join
(
save_dir
,
"discriminator_pytorch_model.safetensors"
)
save_file
(
cpu_state
,
weight_path
)
optimizer_path
=
os
.
path
.
join
(
save_dir
,
"discriminator_optimizer.pt"
)
torch
.
save
(
optim_state
,
optimizer_path
)
main_print
(
f
"--> checkpoint saved at step
{
step
}
"
)
def
save_checkpoint
(
transformer
,
rank
,
output_dir
,
step
):
main_print
(
f
"--> saving checkpoint at step
{
step
}
"
)
with
FSDP
.
state_dict_type
(
transformer
,
StateDictType
.
FULL_STATE_DICT
,
FullStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
),
):
cpu_state
=
transformer
.
state_dict
()
# todo move to get_state_dict
if
rank
<=
0
:
save_dir
=
os
.
path
.
join
(
output_dir
,
f
"checkpoint-
{
step
}
"
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
# save using safetensors
weight_path
=
os
.
path
.
join
(
save_dir
,
"diffusion_pytorch_model.safetensors"
)
save_file
(
cpu_state
,
weight_path
)
config_dict
=
dict
(
transformer
.
config
)
if
"dtype"
in
config_dict
:
del
config_dict
[
"dtype"
]
# TODO
config_path
=
os
.
path
.
join
(
save_dir
,
"config.json"
)
# save dict as json
with
open
(
config_path
,
"w"
)
as
f
:
json
.
dump
(
config_dict
,
f
,
indent
=
4
)
main_print
(
f
"--> checkpoint saved at step
{
step
}
"
)
def
save_checkpoint_generator_discriminator
(
model
,
optimizer
,
discriminator
,
discriminator_optimizer
,
rank
,
output_dir
,
step
,
):
with
FSDP
.
state_dict_type
(
model
,
StateDictType
.
FULL_STATE_DICT
,
FullStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
),
):
cpu_state
=
model
.
state_dict
()
# todo move to get_state_dict
save_dir
=
os
.
path
.
join
(
output_dir
,
f
"checkpoint-
{
step
}
"
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
hf_weight_dir
=
os
.
path
.
join
(
save_dir
,
"hf_weights"
)
os
.
makedirs
(
hf_weight_dir
,
exist_ok
=
True
)
# save using safetensors
if
rank
<=
0
:
config_dict
=
dict
(
model
.
config
)
config_path
=
os
.
path
.
join
(
hf_weight_dir
,
"config.json"
)
# save dict as json
with
open
(
config_path
,
"w"
)
as
f
:
json
.
dump
(
config_dict
,
f
,
indent
=
4
)
weight_path
=
os
.
path
.
join
(
hf_weight_dir
,
"diffusion_pytorch_model.safetensors"
)
save_file
(
cpu_state
,
weight_path
)
main_print
(
f
"--> saved HF weight checkpoint at path
{
hf_weight_dir
}
"
)
model_weight_dir
=
os
.
path
.
join
(
save_dir
,
"model_weights_state"
)
os
.
makedirs
(
model_weight_dir
,
exist_ok
=
True
)
model_optimizer_dir
=
os
.
path
.
join
(
save_dir
,
"model_optimizer_state"
)
os
.
makedirs
(
model_optimizer_dir
,
exist_ok
=
True
)
with
FSDP
.
state_dict_type
(
model
,
StateDictType
.
SHARDED_STATE_DICT
):
optim_state
=
FSDP
.
optim_state_dict
(
model
,
optimizer
)
model_state
=
model
.
state_dict
()
weight_state_dict
=
{
"model"
:
model_state
}
dist_cp
.
save_state_dict
(
state_dict
=
weight_state_dict
,
storage_writer
=
dist_cp
.
FileSystemWriter
(
model_weight_dir
),
planner
=
DefaultSavePlanner
(),
)
optimizer_state_dict
=
{
"optimizer"
:
optim_state
}
dist_cp
.
save_state_dict
(
state_dict
=
optimizer_state_dict
,
storage_writer
=
dist_cp
.
FileSystemWriter
(
model_optimizer_dir
),
planner
=
DefaultSavePlanner
(),
)
discriminator_fsdp_state_dir
=
os
.
path
.
join
(
save_dir
,
"discriminator_fsdp_state"
)
os
.
makedirs
(
discriminator_fsdp_state_dir
,
exist_ok
=
True
)
with
FSDP
.
state_dict_type
(
discriminator
,
StateDictType
.
FULL_STATE_DICT
,
FullStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
),
FullOptimStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
),
):
optim_state
=
FSDP
.
optim_state_dict
(
discriminator
,
discriminator_optimizer
)
model_state
=
discriminator
.
state_dict
()
state_dict
=
{
"optimizer"
:
optim_state
,
"model"
:
model_state
}
if
rank
<=
0
:
discriminator_fsdp_state_fil
=
os
.
path
.
join
(
discriminator_fsdp_state_dir
,
"discriminator_state.pt"
)
torch
.
save
(
state_dict
,
discriminator_fsdp_state_fil
)
main_print
(
"--> saved FSDP state checkpoint"
)
def
load_sharded_model
(
model
,
optimizer
,
model_dir
,
optimizer_dir
):
with
FSDP
.
state_dict_type
(
model
,
StateDictType
.
SHARDED_STATE_DICT
):
weight_state_dict
=
{
"model"
:
model
.
state_dict
()}
optim_state
=
load_sharded_optimizer_state_dict
(
model_state_dict
=
weight_state_dict
[
"model"
],
optimizer_key
=
"optimizer"
,
storage_reader
=
dist_cp
.
FileSystemReader
(
optimizer_dir
),
)
optim_state
=
optim_state
[
"optimizer"
]
flattened_osd
=
FSDP
.
optim_state_dict_to_load
(
model
=
model
,
optim
=
optimizer
,
optim_state_dict
=
optim_state
)
optimizer
.
load_state_dict
(
flattened_osd
)
dist_cp
.
load_state_dict
(
state_dict
=
weight_state_dict
,
storage_reader
=
dist_cp
.
FileSystemReader
(
model_dir
),
planner
=
DefaultLoadPlanner
(),
)
model_state
=
weight_state_dict
[
"model"
]
model
.
load_state_dict
(
model_state
)
main_print
(
f
"--> loaded model and optimizer from path
{
model_dir
}
"
)
return
model
,
optimizer
def
load_full_state_model
(
model
,
optimizer
,
checkpoint_file
,
rank
):
with
FSDP
.
state_dict_type
(
model
,
StateDictType
.
FULL_STATE_DICT
,
FullStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
),
FullOptimStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
),
):
discriminator_state
=
torch
.
load
(
checkpoint_file
)
model_state
=
discriminator_state
[
"model"
]
if
rank
<=
0
:
optim_state
=
discriminator_state
[
"optimizer"
]
else
:
optim_state
=
None
model
.
load_state_dict
(
model_state
)
discriminator_optim_state
=
FSDP
.
optim_state_dict_to_load
(
model
=
model
,
optim
=
optimizer
,
optim_state_dict
=
optim_state
)
optimizer
.
load_state_dict
(
discriminator_optim_state
)
main_print
(
f
"--> loaded discriminator and discriminator optimizer from path
{
checkpoint_file
}
"
)
return
model
,
optimizer
def
resume_training_generator_discriminator
(
model
,
optimizer
,
discriminator
,
discriminator_optimizer
,
checkpoint_dir
,
rank
):
step
=
int
(
checkpoint_dir
.
split
(
"-"
)[
-
1
])
model_weight_dir
=
os
.
path
.
join
(
checkpoint_dir
,
"model_weights_state"
)
model_optimizer_dir
=
os
.
path
.
join
(
checkpoint_dir
,
"model_optimizer_state"
)
model
,
optimizer
=
load_sharded_model
(
model
,
optimizer
,
model_weight_dir
,
model_optimizer_dir
)
discriminator_ckpt_file
=
os
.
path
.
join
(
checkpoint_dir
,
"discriminator_fsdp_state"
,
"discriminator_state.pt"
)
discriminator
,
discriminator_optimizer
=
load_full_state_model
(
discriminator
,
discriminator_optimizer
,
discriminator_ckpt_file
,
rank
)
return
model
,
optimizer
,
discriminator
,
discriminator_optimizer
,
step
def
resume_training
(
model
,
optimizer
,
checkpoint_dir
,
discriminator
=
False
):
weight_path
=
os
.
path
.
join
(
checkpoint_dir
,
"diffusion_pytorch_model.safetensors"
)
if
discriminator
:
weight_path
=
os
.
path
.
join
(
checkpoint_dir
,
"discriminator_pytorch_model.safetensors"
)
model_weights
=
load_file
(
weight_path
)
with
FSDP
.
state_dict_type
(
model
,
StateDictType
.
FULL_STATE_DICT
,
FullStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
),
FullOptimStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
),
):
current_state
=
model
.
state_dict
()
current_state
.
update
(
model_weights
)
model
.
load_state_dict
(
current_state
,
strict
=
False
)
if
discriminator
:
optim_path
=
os
.
path
.
join
(
checkpoint_dir
,
"discriminator_optimizer.pt"
)
else
:
optim_path
=
os
.
path
.
join
(
checkpoint_dir
,
"optimizer.pt"
)
optimizer_state_dict
=
torch
.
load
(
optim_path
,
weights_only
=
False
)
optim_state
=
FSDP
.
optim_state_dict_to_load
(
model
=
model
,
optim
=
optimizer
,
optim_state_dict
=
optimizer_state_dict
)
optimizer
.
load_state_dict
(
optim_state
)
step
=
int
(
checkpoint_dir
.
split
(
"-"
)[
-
1
])
return
model
,
optimizer
,
step
def
save_lora_checkpoint
(
transformer
,
optimizer
,
rank
,
output_dir
,
step
,
pipeline
):
with
FSDP
.
state_dict_type
(
transformer
,
StateDictType
.
FULL_STATE_DICT
,
FullStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
),
):
full_state_dict
=
transformer
.
state_dict
()
lora_optim_state
=
FSDP
.
optim_state_dict
(
transformer
,
optimizer
,
)
if
rank
<=
0
:
save_dir
=
os
.
path
.
join
(
output_dir
,
f
"lora-checkpoint-
{
step
}
"
)
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
# save optimizer
optim_path
=
os
.
path
.
join
(
save_dir
,
"lora_optimizer.pt"
)
torch
.
save
(
lora_optim_state
,
optim_path
)
# save lora weight
main_print
(
f
"--> saving LoRA checkpoint at step
{
step
}
"
)
transformer_lora_layers
=
get_peft_model_state_dict
(
model
=
transformer
,
state_dict
=
full_state_dict
)
pipeline
.
save_lora_weights
(
save_directory
=
save_dir
,
transformer_lora_layers
=
transformer_lora_layers
,
is_main_process
=
True
,
)
# save config
lora_config
=
{
"step"
:
step
,
"lora_params"
:
{
"lora_rank"
:
transformer
.
config
.
lora_rank
,
"lora_alpha"
:
transformer
.
config
.
lora_alpha
,
"target_modules"
:
transformer
.
config
.
lora_target_modules
,
},
}
config_path
=
os
.
path
.
join
(
save_dir
,
"lora_config.json"
)
with
open
(
config_path
,
"w"
)
as
f
:
json
.
dump
(
lora_config
,
f
,
indent
=
4
)
main_print
(
f
"--> LoRA checkpoint saved at step
{
step
}
"
)
def
resume_lora_optimizer
(
transformer
,
checkpoint_dir
,
optimizer
):
config_path
=
os
.
path
.
join
(
checkpoint_dir
,
"lora_config.json"
)
with
open
(
config_path
,
"r"
)
as
f
:
config_dict
=
json
.
load
(
f
)
optim_path
=
os
.
path
.
join
(
checkpoint_dir
,
"lora_optimizer.pt"
)
optimizer_state_dict
=
torch
.
load
(
optim_path
,
weights_only
=
False
)
optim_state
=
FSDP
.
optim_state_dict_to_load
(
model
=
transformer
,
optim
=
optimizer
,
optim_state_dict
=
optimizer_state_dict
)
optimizer
.
load_state_dict
(
optim_state
)
step
=
config_dict
[
"step"
]
main_print
(
f
"--> Successfully resuming LoRA optimizer from step
{
step
}
"
)
return
transformer
,
optimizer
,
step
FastVideo-main/fastvideo/utils/communications.py
0 → 100644
View file @
c07946d8
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
typing
import
Any
,
Tuple
import
torch
import
torch.distributed
as
dist
from
torch
import
Tensor
from
fastvideo.utils.parallel_states
import
nccl_info
def
broadcast
(
input_
:
torch
.
Tensor
):
src
=
nccl_info
.
group_id
*
nccl_info
.
sp_size
dist
.
broadcast
(
input_
,
src
=
src
,
group
=
nccl_info
.
group
)
def
_all_to_all_4D
(
input
:
torch
.
tensor
,
scatter_idx
:
int
=
2
,
gather_idx
:
int
=
1
,
group
=
None
)
->
torch
.
tensor
:
"""
all-to-all for QKV
Args:
input (torch.tensor): a tensor sharded along dim scatter dim
scatter_idx (int): default 1
gather_idx (int): default 2
group : torch process group
Returns:
torch.tensor: resharded tensor (bs, seqlen/P, hc, hs)
"""
assert
(
input
.
dim
()
==
4
),
f
"input must be 4D tensor, got
{
input
.
dim
()
}
and shape
{
input
.
shape
}
"
seq_world_size
=
dist
.
get_world_size
(
group
)
if
scatter_idx
==
2
and
gather_idx
==
1
:
# input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
bs
,
shard_seqlen
,
hc
,
hs
=
input
.
shape
seqlen
=
shard_seqlen
*
seq_world_size
shard_hc
=
hc
//
seq_world_size
# transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
# (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
input_t
=
(
input
.
reshape
(
bs
,
shard_seqlen
,
seq_world_size
,
shard_hc
,
hs
).
transpose
(
0
,
2
).
contiguous
())
output
=
torch
.
empty_like
(
input_t
)
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
# (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
if
seq_world_size
>
1
:
dist
.
all_to_all_single
(
output
,
input_t
,
group
=
group
,
async_op
=
True
)
torch
.
cuda
.
synchronize
()
else
:
output
=
input_t
# if scattering the seq-dim, transpose the heads back to the original dimension
output
=
output
.
reshape
(
seqlen
,
bs
,
shard_hc
,
hs
)
# (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
output
=
output
.
transpose
(
0
,
1
).
contiguous
().
reshape
(
bs
,
seqlen
,
shard_hc
,
hs
)
return
output
elif
scatter_idx
==
1
and
gather_idx
==
2
:
# input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
bs
,
seqlen
,
shard_hc
,
hs
=
input
.
shape
hc
=
shard_hc
*
seq_world_size
shard_seqlen
=
seqlen
//
seq_world_size
seq_world_size
=
dist
.
get_world_size
(
group
)
# transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
# (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
input_t
=
(
input
.
reshape
(
bs
,
seq_world_size
,
shard_seqlen
,
shard_hc
,
hs
).
transpose
(
0
,
3
).
transpose
(
0
,
1
).
contiguous
().
reshape
(
seq_world_size
,
shard_hc
,
shard_seqlen
,
bs
,
hs
))
output
=
torch
.
empty_like
(
input_t
)
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
# (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
if
seq_world_size
>
1
:
dist
.
all_to_all_single
(
output
,
input_t
,
group
=
group
)
torch
.
cuda
.
synchronize
()
else
:
output
=
input_t
# if scattering the seq-dim, transpose the heads back to the original dimension
output
=
output
.
reshape
(
hc
,
shard_seqlen
,
bs
,
hs
)
# (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs)
output
=
output
.
transpose
(
0
,
2
).
contiguous
().
reshape
(
bs
,
shard_seqlen
,
hc
,
hs
)
return
output
else
:
raise
RuntimeError
(
"scatter_idx must be 1 or 2 and gather_idx must be 1 or 2"
)
class
SeqAllToAll4D
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
:
Any
,
group
:
dist
.
ProcessGroup
,
input
:
Tensor
,
scatter_idx
:
int
,
gather_idx
:
int
,
)
->
Tensor
:
ctx
.
group
=
group
ctx
.
scatter_idx
=
scatter_idx
ctx
.
gather_idx
=
gather_idx
return
_all_to_all_4D
(
input
,
scatter_idx
,
gather_idx
,
group
=
group
)
@
staticmethod
def
backward
(
ctx
:
Any
,
*
grad_output
:
Tensor
)
->
Tuple
[
None
,
Tensor
,
None
,
None
]:
return
(
None
,
SeqAllToAll4D
.
apply
(
ctx
.
group
,
*
grad_output
,
ctx
.
gather_idx
,
ctx
.
scatter_idx
),
None
,
None
,
)
def
all_to_all_4D
(
input_
:
torch
.
Tensor
,
scatter_dim
:
int
=
2
,
gather_dim
:
int
=
1
,
):
return
SeqAllToAll4D
.
apply
(
nccl_info
.
group
,
input_
,
scatter_dim
,
gather_dim
)
def
_all_to_all
(
input_
:
torch
.
Tensor
,
world_size
:
int
,
group
:
dist
.
ProcessGroup
,
scatter_dim
:
int
,
gather_dim
:
int
,
):
input_list
=
[
t
.
contiguous
()
for
t
in
torch
.
tensor_split
(
input_
,
world_size
,
scatter_dim
)]
output_list
=
[
torch
.
empty_like
(
input_list
[
0
])
for
_
in
range
(
world_size
)]
dist
.
all_to_all
(
output_list
,
input_list
,
group
=
group
)
return
torch
.
cat
(
output_list
,
dim
=
gather_dim
).
contiguous
()
class
_AllToAll
(
torch
.
autograd
.
Function
):
"""All-to-all communication.
Args:
input_: input matrix
process_group: communication group
scatter_dim: scatter dimension
gather_dim: gather dimension
"""
@
staticmethod
def
forward
(
ctx
,
input_
,
process_group
,
scatter_dim
,
gather_dim
):
ctx
.
process_group
=
process_group
ctx
.
scatter_dim
=
scatter_dim
ctx
.
gather_dim
=
gather_dim
ctx
.
world_size
=
dist
.
get_world_size
(
process_group
)
output
=
_all_to_all
(
input_
,
ctx
.
world_size
,
process_group
,
scatter_dim
,
gather_dim
)
return
output
@
staticmethod
def
backward
(
ctx
,
grad_output
):
grad_output
=
_all_to_all
(
grad_output
,
ctx
.
world_size
,
ctx
.
process_group
,
ctx
.
gather_dim
,
ctx
.
scatter_dim
,
)
return
(
grad_output
,
None
,
None
,
None
,
)
def
all_to_all
(
input_
:
torch
.
Tensor
,
scatter_dim
:
int
=
2
,
gather_dim
:
int
=
1
,
):
return
_AllToAll
.
apply
(
input_
,
nccl_info
.
group
,
scatter_dim
,
gather_dim
)
class
_AllGather
(
torch
.
autograd
.
Function
):
"""All-gather communication with autograd support.
Args:
input_: input tensor
dim: dimension along which to concatenate
"""
@
staticmethod
def
forward
(
ctx
,
input_
,
dim
):
ctx
.
dim
=
dim
world_size
=
nccl_info
.
sp_size
group
=
nccl_info
.
group
input_size
=
list
(
input_
.
size
())
ctx
.
input_size
=
input_size
[
dim
]
tensor_list
=
[
torch
.
empty_like
(
input_
)
for
_
in
range
(
world_size
)]
input_
=
input_
.
contiguous
()
dist
.
all_gather
(
tensor_list
,
input_
,
group
=
group
)
output
=
torch
.
cat
(
tensor_list
,
dim
=
dim
)
return
output
@
staticmethod
def
backward
(
ctx
,
grad_output
):
world_size
=
nccl_info
.
sp_size
rank
=
nccl_info
.
rank_within_group
dim
=
ctx
.
dim
input_size
=
ctx
.
input_size
sizes
=
[
input_size
]
*
world_size
grad_input_list
=
torch
.
split
(
grad_output
,
sizes
,
dim
=
dim
)
grad_input
=
grad_input_list
[
rank
]
return
grad_input
,
None
def
all_gather
(
input_
:
torch
.
Tensor
,
dim
:
int
=
1
):
"""Performs an all-gather operation on the input tensor along the specified dimension.
Args:
input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
dim (int, optional): Dimension along which to concatenate. Defaults to 1.
Returns:
torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'.
"""
return
_AllGather
.
apply
(
input_
,
dim
)
def
prepare_sequence_parallel_data
(
hidden_states
,
encoder_hidden_states
,
attention_mask
,
encoder_attention_mask
):
if
nccl_info
.
sp_size
==
1
:
return
(
hidden_states
,
encoder_hidden_states
,
attention_mask
,
encoder_attention_mask
,
)
def
prepare
(
hidden_states
,
encoder_hidden_states
,
attention_mask
,
encoder_attention_mask
):
hidden_states
=
all_to_all
(
hidden_states
,
scatter_dim
=
2
,
gather_dim
=
0
)
encoder_hidden_states
=
all_to_all
(
encoder_hidden_states
,
scatter_dim
=
1
,
gather_dim
=
0
)
attention_mask
=
all_to_all
(
attention_mask
,
scatter_dim
=
1
,
gather_dim
=
0
)
encoder_attention_mask
=
all_to_all
(
encoder_attention_mask
,
scatter_dim
=
1
,
gather_dim
=
0
)
return
(
hidden_states
,
encoder_hidden_states
,
attention_mask
,
encoder_attention_mask
,
)
sp_size
=
nccl_info
.
sp_size
frame
=
hidden_states
.
shape
[
2
]
assert
frame
%
sp_size
==
0
,
"frame should be a multiple of sp_size"
(
hidden_states
,
encoder_hidden_states
,
attention_mask
,
encoder_attention_mask
,
)
=
prepare
(
hidden_states
,
encoder_hidden_states
.
repeat
(
1
,
sp_size
,
1
),
attention_mask
.
repeat
(
1
,
sp_size
,
1
,
1
),
encoder_attention_mask
.
repeat
(
1
,
sp_size
),
)
return
hidden_states
,
encoder_hidden_states
,
attention_mask
,
encoder_attention_mask
def
sp_parallel_dataloader_wrapper
(
dataloader
,
device
,
train_batch_size
,
sp_size
,
train_sp_batch_size
):
while
True
:
for
data_item
in
dataloader
:
latents
,
cond
,
attn_mask
,
cond_mask
=
data_item
latents
=
latents
.
to
(
device
)
cond
=
cond
.
to
(
device
)
attn_mask
=
attn_mask
.
to
(
device
)
cond_mask
=
cond_mask
.
to
(
device
)
frame
=
latents
.
shape
[
2
]
if
frame
==
1
:
yield
latents
,
cond
,
attn_mask
,
cond_mask
else
:
latents
,
cond
,
attn_mask
,
cond_mask
=
prepare_sequence_parallel_data
(
latents
,
cond
,
attn_mask
,
cond_mask
)
assert
(
train_batch_size
*
sp_size
>=
train_sp_batch_size
),
"train_batch_size * sp_size should be greater than train_sp_batch_size"
for
iter
in
range
(
train_batch_size
*
sp_size
//
train_sp_batch_size
):
st_idx
=
iter
*
train_sp_batch_size
ed_idx
=
(
iter
+
1
)
*
train_sp_batch_size
encoder_hidden_states
=
cond
[
st_idx
:
ed_idx
]
attention_mask
=
attn_mask
[
st_idx
:
ed_idx
]
encoder_attention_mask
=
cond_mask
[
st_idx
:
ed_idx
]
yield
(
latents
[
st_idx
:
ed_idx
],
encoder_hidden_states
,
attention_mask
,
encoder_attention_mask
,
)
FastVideo-main/fastvideo/utils/dataset_utils.py
0 → 100644
View file @
c07946d8
import
math
import
random
from
collections
import
Counter
from
typing
import
List
,
Optional
import
decord
import
torch
import
torch.utils
import
torch.utils.data
from
torch.nn
import
functional
as
F
from
torch.utils.data
import
Sampler
IMG_EXTENSIONS
=
[
".jpg"
,
".JPG"
,
".jpeg"
,
".JPEG"
,
".png"
,
".PNG"
]
def
is_image_file
(
filename
):
return
any
(
filename
.
endswith
(
extension
)
for
extension
in
IMG_EXTENSIONS
)
class
DecordInit
(
object
):
"""Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""
def
__init__
(
self
,
num_threads
=
1
):
self
.
num_threads
=
num_threads
self
.
ctx
=
decord
.
cpu
(
0
)
def
__call__
(
self
,
filename
):
"""Perform the Decord initialization.
Args:
results (dict): The resulting dict to be modified and passed
to the next transform in pipeline.
"""
reader
=
decord
.
VideoReader
(
filename
,
ctx
=
self
.
ctx
,
num_threads
=
self
.
num_threads
)
return
reader
def
__repr__
(
self
):
repr_str
=
(
f
"
{
self
.
__class__
.
__name__
}
("
f
"sr=
{
self
.
sr
}
,"
f
"num_threads=
{
self
.
num_threads
}
)"
)
return
repr_str
def
pad_to_multiple
(
number
,
ds_stride
):
remainder
=
number
%
ds_stride
if
remainder
==
0
:
return
number
else
:
padding
=
ds_stride
-
remainder
return
number
+
padding
# TODO
class Collate:

    def __init__(self, args):
        self.batch_size = args.train_batch_size
        self.group_frame = args.group_frame
        self.group_resolution = args.group_resolution

        self.max_height = args.max_height
        self.max_width = args.max_width
        self.ae_stride = args.ae_stride

        self.ae_stride_t = args.ae_stride_t
        self.ae_stride_thw = (self.ae_stride_t, self.ae_stride, self.ae_stride)

        self.patch_size = args.patch_size
        self.patch_size_t = args.patch_size_t

        self.num_frames = args.num_frames
        self.use_image_num = args.use_image_num
        self.max_thw = (self.num_frames, self.max_height, self.max_width)

    def package(self, batch):
        batch_tubes = [i["pixel_values"] for i in batch]  # b [c t h w]
        input_ids = [i["input_ids"] for i in batch]  # b [1 l]
        cond_mask = [i["cond_mask"] for i in batch]  # b [1 l]
        return batch_tubes, input_ids, cond_mask

    def __call__(self, batch):
        batch_tubes, input_ids, cond_mask = self.package(batch)

        ds_stride = self.ae_stride * self.patch_size
        t_ds_stride = self.ae_stride_t * self.patch_size_t

        pad_batch_tubes, attention_mask, input_ids, cond_mask = self.process(
            batch_tubes,
            input_ids,
            cond_mask,
            t_ds_stride,
            ds_stride,
            self.max_thw,
            self.ae_stride_thw,
        )
        assert not torch.any(torch.isnan(pad_batch_tubes)), "after pad_batch_tubes"
        return pad_batch_tubes, attention_mask, input_ids, cond_mask

    def process(
        self,
        batch_tubes,
        input_ids,
        cond_mask,
        t_ds_stride,
        ds_stride,
        max_thw,
        ae_stride_thw,
    ):
        # pad to max multiple of ds_stride
        batch_input_size = [i.shape for i in batch_tubes]  # [(c t h w), (c t h w)]
        assert len(batch_input_size) == self.batch_size
        if self.group_frame or self.group_resolution or self.batch_size == 1:
            len_each_batch = batch_input_size
            idx_length_dict = dict([*zip(list(range(self.batch_size)), len_each_batch)])
            count_dict = Counter(len_each_batch)
            if len(count_dict) != 1:
                sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1])
                pick_length = sorted_by_value[-1][0]  # the highest frequency
                candidate_batch = [
                    idx for idx, length in idx_length_dict.items() if length == pick_length
                ]
                random_select_batch = [
                    random.choice(candidate_batch)
                    for _ in range(len(len_each_batch) - len(candidate_batch))
                ]
                print(
                    batch_input_size,
                    idx_length_dict,
                    count_dict,
                    sorted_by_value,
                    pick_length,
                    candidate_batch,
                    random_select_batch,
                )
                pick_idx = candidate_batch + random_select_batch

                batch_tubes = [batch_tubes[i] for i in pick_idx]
                batch_input_size = [i.shape for i in batch_tubes]  # [(c t h w), (c t h w)]
                input_ids = [input_ids[i] for i in pick_idx]  # b [1, l]
                cond_mask = [cond_mask[i] for i in pick_idx]  # b [1, l]

            for i in range(1, self.batch_size):
                assert batch_input_size[0] == batch_input_size[i]
            max_t = max([i[1] for i in batch_input_size])
            max_h = max([i[2] for i in batch_input_size])
            max_w = max([i[3] for i in batch_input_size])
        else:
            max_t, max_h, max_w = max_thw
        pad_max_t, pad_max_h, pad_max_w = (
            pad_to_multiple(max_t - 1 + self.ae_stride_t, t_ds_stride),
            pad_to_multiple(max_h, ds_stride),
            pad_to_multiple(max_w, ds_stride),
        )
        pad_max_t = pad_max_t + 1 - self.ae_stride_t
        each_pad_t_h_w = [[
            pad_max_t - i.shape[1],
            pad_max_h - i.shape[2],
            pad_max_w - i.shape[3],
        ] for i in batch_tubes]
        pad_batch_tubes = [
            F.pad(im, (0, pad_w, 0, pad_h, 0, pad_t), value=0)
            for (pad_t, pad_h, pad_w), im in zip(each_pad_t_h_w, batch_tubes)
        ]
        pad_batch_tubes = torch.stack(pad_batch_tubes, dim=0)

        max_tube_size = [pad_max_t, pad_max_h, pad_max_w]
        max_latent_size = [
            ((max_tube_size[0] - 1) // ae_stride_thw[0] + 1),
            max_tube_size[1] // ae_stride_thw[1],
            max_tube_size[2] // ae_stride_thw[2],
        ]
        valid_latent_size = [[
            int(math.ceil((i[1] - 1) / ae_stride_thw[0])) + 1,
            int(math.ceil(i[2] / ae_stride_thw[1])),
            int(math.ceil(i[3] / ae_stride_thw[2])),
        ] for i in batch_input_size]
        attention_mask = [
            F.pad(
                torch.ones(i, dtype=pad_batch_tubes.dtype),
                (
                    0,
                    max_latent_size[2] - i[2],
                    0,
                    max_latent_size[1] - i[1],
                    0,
                    max_latent_size[0] - i[0],
                ),
                value=0,
            ) for i in valid_latent_size
        ]
        attention_mask = torch.stack(attention_mask)  # b t h w
        if self.batch_size == 1 or self.group_frame or self.group_resolution:
            assert torch.all(attention_mask.bool())

        input_ids = torch.stack(input_ids)  # b 1 l
        cond_mask = torch.stack(cond_mask)  # b 1 l

        return pad_batch_tubes, attention_mask, input_ids, cond_mask
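In a training script, an instance of this class is typically passed as a DataLoader's collate_fn (for example collate_fn=Collate(args) alongside batch_size=args.train_batch_size), so that each batch comes back as padded pixel tubes together with the matching latent-space attention mask, text token ids, and text mask.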
def split_to_even_chunks(indices, lengths, num_chunks, batch_size):
    """
    Split a list of indices into `chunks` chunks of roughly equal lengths.
    """
    if len(indices) % num_chunks != 0:
        chunks = [indices[i::num_chunks] for i in range(num_chunks)]
    else:
        num_indices_per_chunk = len(indices) // num_chunks

        chunks = [[] for _ in range(num_chunks)]
        chunks_lengths = [0 for _ in range(num_chunks)]
        for index in indices:
            shortest_chunk = chunks_lengths.index(min(chunks_lengths))
            chunks[shortest_chunk].append(index)
            chunks_lengths[shortest_chunk] += lengths[index]
            if len(chunks[shortest_chunk]) == num_indices_per_chunk:
                chunks_lengths[shortest_chunk] = float("inf")
    # return chunks
    pad_chunks = []
    for idx, chunk in enumerate(chunks):
        if batch_size != len(chunk):
            assert batch_size > len(chunk)
            if len(chunk) != 0:
                chunk = chunk + [
                    random.choice(chunk) for _ in range(batch_size - len(chunk))
                ]
            else:
                chunk = random.choice(pad_chunks)
            print(chunks[idx], "->", chunk)
        pad_chunks.append(chunk)
    return pad_chunks
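A worked example with made-up lengths, showing the greedy balancing when the indices divide evenly across chunks:

# Six samples whose per-index lengths are [8, 7, 6, 5, 4, 3], balanced across
# two ranks with batch_size=3; each index joins the currently lightest chunk.
chunks = split_to_even_chunks([0, 1, 2, 3, 4, 5], [8, 7, 6, 5, 4, 3],
                              num_chunks=2, batch_size=3)
assert chunks == [[0, 3, 4], [1, 2, 5]]  # totals: 8+5+4 = 17 vs 7+6+3 = 16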
def group_frame_fun(indices, lengths):
    # sort by num_frames
    indices.sort(key=lambda i: lengths[i], reverse=True)
    return indices
def megabatch_frame_alignment(megabatches, lengths):
    aligned_magabatches = []
    for _, megabatch in enumerate(megabatches):
        assert len(megabatch) != 0
        len_each_megabatch = [lengths[i] for i in megabatch]
        idx_length_dict = dict([*zip(megabatch, len_each_megabatch)])
        count_dict = Counter(len_each_megabatch)
        # mixed frame length, align megabatch inside
        if len(count_dict) != 1:
            sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1])
            pick_length = sorted_by_value[-1][0]  # the highest frequency
            candidate_batch = [
                idx for idx, length in idx_length_dict.items() if length == pick_length
            ]
            random_select_batch = [
                random.choice(candidate_batch)
                for i in range(len(idx_length_dict) - len(candidate_batch))
            ]
            aligned_magabatch = candidate_batch + random_select_batch
            aligned_magabatches.append(aligned_magabatch)
        # already aligned megabatches
        else:
            aligned_magabatches.append(megabatch)
    return aligned_magabatches
def get_length_grouped_indices(
    lengths,
    batch_size,
    world_size,
    generator=None,
    group_frame=False,
    group_resolution=False,
    seed=42,
):
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    if generator is None:
        generator = torch.Generator().manual_seed(seed)
    # every rank will generate a fixed order but random index
    indices = torch.randperm(len(lengths), generator=generator).tolist()
    # sort dataset according to frame count
    indices = group_frame_fun(indices, lengths)
    # chunk dataset into megabatches
    megabatch_size = world_size * batch_size
    megabatches = [
        indices[i:i + megabatch_size] for i in range(0, len(lengths), megabatch_size)
    ]
    # make sure the lengths within each megabatch are aligned with each other
    megabatches = megabatch_frame_alignment(megabatches, lengths)
    # split each aligned megabatch into batches
    megabatches = [
        split_to_even_chunks(megabatch, lengths, world_size, batch_size)
        for megabatch in megabatches
    ]
    # shuffle megabatches to do video-image mix training
    indices = torch.randperm(len(megabatches), generator=generator).tolist()
    shuffled_megabatches = [megabatches[i] for i in indices]
    # expand indices and return
    return [
        i for megabatch in shuffled_megabatches for batch in megabatch for i in batch
    ]
class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        rank: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        group_frame=False,
        group_resolution=False,
        generator=None,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.rank = rank
        self.world_size = world_size
        self.lengths = lengths
        self.group_frame = group_frame
        self.group_resolution = group_resolution
        self.generator = generator

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        indices = get_length_grouped_indices(
            self.lengths,
            self.batch_size,
            self.world_size,
            group_frame=self.group_frame,
            group_resolution=self.group_resolution,
            generator=self.generator,
        )

        def distributed_sampler(lst, rank, batch_size, world_size):
            result = []
            index = rank * batch_size
            while index < len(lst):
                result.extend(lst[index:index + batch_size])
                index += batch_size * world_size
            return result

        indices = distributed_sampler(indices, self.rank, self.batch_size,
                                      self.world_size)
        return iter(indices)
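A minimal single-rank sketch of plugging the sampler into a DataLoader; the toy dataset and frame lengths are hypothetical.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(4))  # four dummy samples
lengths = [17, 33, 33, 17]                # frame count per sample index

sampler = LengthGroupedSampler(batch_size=2, rank=0, world_size=1,
                               lengths=lengths, group_frame=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
for batch in loader:
    print(batch)  # samples with the same frame count land in the same batch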
FastVideo-main/fastvideo/utils/env_utils.py 0 → 100644
import platform

import accelerate
import peft
import torch
import transformers
from transformers.utils import is_torch_cuda_available, is_torch_npu_available

VERSION = "1.2.0"

if __name__ == "__main__":
    info = {
        "FastVideo version": VERSION,
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "PyTorch version": torch.__version__,
        "Transformers version": transformers.__version__,
        "Accelerate version": accelerate.__version__,
        "PEFT version": peft.__version__,
    }

    if is_torch_cuda_available():
        info["PyTorch version"] += " (GPU)"
        info["GPU type"] = torch.cuda.get_device_name()

    if is_torch_npu_available():
        info["PyTorch version"] += " (NPU)"
        info["NPU type"] = torch.npu.get_device_name()
        info["CANN version"] = torch.version.cann  # codespell:ignore

    try:
        import bitsandbytes

        info["Bitsandbytes version"] = bitsandbytes.__version__
    except Exception:
        pass

    print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
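Running the module directly (for example, python FastVideo-main/fastvideo/utils/env_utils.py) prints one "- key: value" line per entry collected above; the GPU/NPU and bitsandbytes details appear only when those backends are available.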
FastVideo-main/fastvideo/utils/fsdp_util.py 0 → 100644
# ruff: noqa: E731
import functools
from functools import partial

import torch
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper)
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformerBlock
from fastvideo.utils.load import get_no_split_modules

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

check_fn = lambda submodule: isinstance(submodule, MochiTransformerBlock)
def apply_fsdp_checkpointing(model, no_split_modules, p=1):
    # https://github.com/foundation-model-stack/fms-fsdp/blob/408c7516d69ea9b6bcd4c0f5efab26c0f64b3c2d/fms_fsdp/policies/ac_handler.py#L16
    """apply activation checkpointing to model
    returns None as model is updated directly
    """
    print("--> applying fsdp activation checkpointing...")
    block_idx = 0
    cut_off = 1 / 2
    # when passing p as a fraction number (e.g. 1/3), it will be interpreted
    # as a string in argv, thus we need eval("1/3") here for fractions.
    p = eval(p) if isinstance(p, str) else p

    def selective_checkpointing(submodule):
        nonlocal block_idx
        nonlocal cut_off

        if isinstance(submodule, no_split_modules):
            block_idx += 1
            if block_idx * p >= cut_off:
                cut_off += 1
                return True
        return False

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=selective_checkpointing,
    )
def get_mixed_precision(master_weight_type="fp32"):
    weight_type = torch.float32 if master_weight_type == "fp32" else torch.bfloat16
    mixed_precision = MixedPrecision(
        param_dtype=weight_type,
        # Gradient communication precision.
        reduce_dtype=weight_type,
        # Buffer precision.
        buffer_dtype=weight_type,
        cast_forward_inputs=False,
    )
    return mixed_precision
def get_dit_fsdp_kwargs(
    transformer,
    sharding_strategy,
    use_lora=False,
    cpu_offload=False,
    master_weight_type="fp32",
):
    no_split_modules = get_no_split_modules(transformer)
    if use_lora:
        auto_wrap_policy = fsdp_auto_wrap_policy
    else:
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=no_split_modules,
        )

    # we use float32 for fsdp but autocast during training
    mixed_precision = get_mixed_precision(master_weight_type)

    if sharding_strategy == "full":
        sharding_strategy = ShardingStrategy.FULL_SHARD
    elif sharding_strategy == "hybrid_full":
        sharding_strategy = ShardingStrategy.HYBRID_SHARD
    elif sharding_strategy == "none":
        sharding_strategy = ShardingStrategy.NO_SHARD
        auto_wrap_policy = None
    elif sharding_strategy == "hybrid_zero2":
        sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2

    device_id = torch.cuda.current_device()
    cpu_offload = (torch.distributed.fsdp.CPUOffload(offload_params=True)
                   if cpu_offload else None)
    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "mixed_precision": mixed_precision,
        "sharding_strategy": sharding_strategy,
        "device_id": device_id,
        "limit_all_gathers": True,
        "cpu_offload": cpu_offload,
    }

    # Add LoRA-specific settings when LoRA is enabled
    if use_lora:
        fsdp_kwargs.update({
            "use_orig_params": False,  # Required for LoRA memory savings
            "sync_module_states": True,
        })

    return fsdp_kwargs, no_split_modules
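A minimal wiring sketch, assuming a DiT transformer has already been built and torch.distributed is initialized; this is illustrative and not lifted from the training scripts.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# `transformer` is assumed to be an already-constructed DiT model.
fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(transformer, "full")
transformer = FSDP(transformer, **fsdp_kwargs)
apply_fsdp_checkpointing(transformer, no_split_modules, p=1)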
def get_discriminator_fsdp_kwargs(master_weight_type="fp32"):
    auto_wrap_policy = None

    # Use existing mixed precision settings
    mixed_precision = get_mixed_precision(master_weight_type)
    sharding_strategy = ShardingStrategy.NO_SHARD
    device_id = torch.cuda.current_device()
    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "mixed_precision": mixed_precision,
        "sharding_strategy": sharding_strategy,
        "device_id": device_id,
        "limit_all_gathers": True,
    }

    return fsdp_kwargs