ModelZoo / STAR · Commits

Commit 1f5da520, authored Dec 05, 2025 by yangzhong

    git init

Pipeline #3144 failed with stages in 0 seconds
Changes: 326 · Pipelines: 1
Showing 20 changed files with 2937 additions and 0 deletions (+2937, -0)
utils_data/opensora/datasets/high_order/utils_noise.py  +218 -0
utils_data/opensora/datasets/high_order/utils_resize.py  +31 -0
utils_data/opensora/datasets/utils.py  +332 -0
utils_data/opensora/datasets/video_transforms.py  +609 -0
utils_data/opensora/datasets/wavelet_color_fix.py  +119 -0
utils_data/opensora/models/__init__.py  +6 -0
utils_data/opensora/models/__pycache__/__init__.cpython-39.pyc  +0 -0
utils_data/opensora/models/dit/__init__.py  +1 -0
utils_data/opensora/models/dit/__pycache__/__init__.cpython-39.pyc  +0 -0
utils_data/opensora/models/dit/__pycache__/dit.cpython-39.pyc  +0 -0
utils_data/opensora/models/dit/dit.py  +288 -0
utils_data/opensora/models/latte/__init__.py  +1 -0
utils_data/opensora/models/latte/__pycache__/__init__.cpython-39.pyc  +0 -0
utils_data/opensora/models/latte/__pycache__/latte.cpython-39.pyc  +0 -0
utils_data/opensora/models/latte/latte.py  +112 -0
utils_data/opensora/models/layers/__init__.py  +0 -0
utils_data/opensora/models/layers/__pycache__/__init__.cpython-39.pyc  +0 -0
utils_data/opensora/models/layers/__pycache__/blocks.cpython-39.pyc  +0 -0
utils_data/opensora/models/layers/__pycache__/timm_uvit.cpython-39.pyc  +0 -0
utils_data/opensora/models/layers/blocks.py  +1220 -0
utils_data/opensora/datasets/high_order/utils_noise.py (new file, mode 100644)

import matplotlib.pyplot as plt
import torch
import numpy as np
from PIL import Image
import cv2
import torchvision.transforms

"""
1. Gaussian noise
generate_gaussian_noise
random_generate_gaussian_noise
random_add_gaussian_noise
add_gaussian_noise
2. Poisson noise
random_add_poisson_noise
random_generate_poisson_noise
generate_poisson_noise
add_poisson_noise
"""

'''
generate_gaussian_noise
add_gaussian_noise
generate_gaussian_noise_pt
add_gaussian_noise_pt
random_generate_gaussian_noise
random_add_gaussian_noise
random_generate_gaussian_noise_pt
random_add_gaussian_noise_pt
'''


# -------------------------------------------------------------------- #
# --------------------random_add_gaussian_noise----------------------- #
# -------------------------------------------------------------------- #
def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
    sigma = torch.rand(img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_gaussian_noise_pt(img, sigma, gray_noise)


def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
    """Add Gaussian noise (PyTorch version).
    Args:
        img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
        scale (float | Tensor): Noise scale. Default: 1.0.
    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    b, _, h, w = img.size()
    if not isinstance(sigma, (float, int)):
        sigma = sigma.view(img.size(0), 1, 1, 1)
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0

    if cal_gray_noise:
        noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
        noise_gray = noise_gray.view(b, 1, h, w)

    # always calculate color noise
    noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.

    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    return noise


def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
    """Add Gaussian noise (PyTorch version).
    Args:
        img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
        scale (float | Tensor): Noise scale. Default: 1.0.
    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# -------------------------------------------------------------------- #
# --------------------random_add_poisson_noise------------------------ #
# -------------------------------------------------------------------- #
def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
    scale = torch.rand(img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_poisson_noise_pt(img, scale, gray_noise)


def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
    """Generate a batch of poisson noise (PyTorch version)
    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.
    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    b, _, h, w = img.size()
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0
    if cal_gray_noise:
        # img_gray = rgb_to_grayscale(img, num_output_channels=1)
        img_gray = img[:, 0, :, :]  # size: BHW
        img_gray = torch.unsqueeze(img_gray, 1)
        # round and clip image for counting vals correctly
        img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
        # use for-loop to get the unique values for each sample
        vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
        vals_list = [2 ** np.ceil(np.log2(vals)) for vals in vals_list]
        vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
        out = torch.poisson(img_gray * vals) / vals
        noise_gray = out - img_gray
        noise_gray = noise_gray.expand(b, 3, h, w)

    # always calculate color noise
    # round and clip image for counting vals correctly
    img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
    # use for-loop to get the unique values for each sample
    vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
    vals_list = [2 ** np.ceil(np.log2(vals)) for vals in vals_list]
    vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
    out = torch.poisson(img * vals) / vals
    noise = out - img
    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    if not isinstance(scale, (float, int)):
        scale = scale.view(b, 1, 1, 1)
    return noise * scale


def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
    """Add poisson noise to a batch of images (PyTorch version).
    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.
    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    noise = generate_poisson_noise_pt(img, scale, gray_noise)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


if __name__ == '__main__':
    gaussian_noise_prob2 = 0.5
    noise_range2 = [1, 25]
    poisson_scale_range2 = [0.05, 2.5]
    gray_noise_prob2 = 0.4

    # img = cv2.imread('../qj.png')
    # img = np.float32(img / 255.)
    # noise = random_generate_poisson_noise_pt(img, noise_range2, gray_noise_prob2)
    # img_noise = random_add_gaussian_noise_pt(img, noise_range2, gray_noise_prob2)
    # print(noise.shape)
    #
    # img_noise = np.uint8((img_noise.clip(0, 1) * 255.).round())

    img = Image.open('../dog.jpg')
    img = torchvision.transforms.ToTensor()(img)
    img = img.unsqueeze(0)
    img_noise = random_add_poisson_noise_pt(img, poisson_scale_range2, gray_noise_prob2)
    img_noise = img_noise.squeeze(0).permute(1, 2, 0).detach().numpy()
    plt.imshow(img_noise)
    plt.show()
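A minimal usage sketch for the two batched entry points above, random_add_gaussian_noise_pt and random_add_poisson_noise_pt; it is not part of the committed file, and the import path is only an assumption about how the module would be made importable. Input shape and value range follow the docstrings.

# Sketch (assumes this file is importable as utils_noise; adjust to your package layout).
import torch
# from utils_noise import random_add_gaussian_noise_pt, random_add_poisson_noise_pt

frames = torch.rand(4, 3, 256, 256)  # (b, c, h, w), float32 in [0, 1]
noisy = random_add_gaussian_noise_pt(frames, sigma_range=(1, 25), gray_prob=0.4)
noisy = random_add_poisson_noise_pt(noisy, scale_range=(0.05, 2.5), gray_prob=0.4)
print(noisy.shape, noisy.min().item(), noisy.max().item())  # still (4, 3, 256, 256), clipped to [0, 1]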
utils_data/opensora/datasets/high_order/utils_resize.py (new file, mode 100644)

import torch
from torch.nn import functional as F
import cv2
import random
import numpy as np

'''
enum InterpolationFlags
{
    ``'bilinear'`` | ``'bicubic'`` | ``'area'``
};
'''


def random_resizing(image, updown_type, resize_prob, mode_list, resize_range):
    b, c, h, w = image.shape
    updown_type = random.choices(updown_type, resize_prob)[0]  # random.choices returns a list (e.g. ["up"]), so take the first element with [0]
    mode = random.choice(mode_list)
    if updown_type == "up":
        scale = np.random.uniform(1, resize_range[1])
    elif updown_type == "down":
        scale = np.random.uniform(resize_range[0], 1)
    else:
        scale = 1
    image = F.interpolate(image, scale_factor=scale, mode=random.choice(['area', 'bilinear', 'bicubic']))
    # image = cv2.resize(image, (w, h), interpolation=flags)
    image = torch.clamp(image, 0.0, 1.0)
    return image
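A brief sketch of how random_resizing might be called on a batch of frames; it is not part of the committed file, and the category names, weights, and range below are illustrative values rather than values taken from the repository's configs.

import torch
# from utils_resize import random_resizing  # adjust to your package layout

frames = torch.rand(2, 3, 128, 128)  # (b, c, h, w) in [0, 1]
resized = random_resizing(
    frames,
    updown_type=["up", "down", "keep"],      # candidate categories for random.choices
    resize_prob=[0.3, 0.4, 0.3],             # weights for the categories above
    mode_list=["area", "bilinear", "bicubic"],
    resize_range=[0.5, 1.5],                 # [max down-scale, max up-scale]
)
print(resized.shape)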
utils_data/opensora/datasets/utils.py (new file, mode 100644)

import os
import re
from typing import Iterator, Optional

from torch.distributed import ProcessGroup
import numpy as np
import pandas as pd
import requests
import torch
import cv2
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data.distributed import DistributedSampler
from torchvision.io import write_video
from torchvision.utils import save_image
import random

from . import video_transforms
from .wavelet_color_fix import adain_color_fix

VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

regex = re.compile(
    r"^(?:http|ftp)s?://"  # http:// or https://
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|"  # domain...
    r"localhost|"  # localhost...
    r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"  # ...or ip
    r"(?::\d+)?"  # optional port
    r"(?:/?|[/?]\S+)$",
    re.IGNORECASE,
)


def is_url(url):
    return re.match(regex, url) is not None


def read_file(input_path):
    if input_path.endswith(".csv"):
        return pd.read_csv(input_path)
    elif input_path.endswith(".parquet"):
        return pd.read_parquet(input_path)
    else:
        raise NotImplementedError(f"Unsupported file format: {input_path}")


def download_url(input_path):
    output_dir = "cache"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    base_name = os.path.basename(input_path)
    output_path = os.path.join(output_dir, base_name)
    img_data = requests.get(input_path).content
    with open(output_path, "wb") as handler:
        handler.write(img_data)
    print(f"URL {input_path} downloaded to {output_path}")
    return output_path


def temporal_random_crop(vframes, num_frames, frame_interval):
    temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
    total_frames = len(vframes)
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    assert end_frame_ind - start_frame_ind >= num_frames
    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int)
    video = vframes[frame_indice]
    return video


def compute_bidirectional_optical_flow(video_frames):
    video_frames = video_frames.permute(0, 2, 3, 1).numpy()  # T C H W -> T H W C
    T, H, W, _ = video_frames.shape
    bidirectional_flow = torch.zeros((2, T - 1, H, W))
    for t in range(T - 1):
        prev_frame = cv2.cvtColor(video_frames[t], cv2.COLOR_RGB2GRAY)
        next_frame = cv2.cvtColor(video_frames[t + 1], cv2.COLOR_RGB2GRAY)
        # compute forward optical flow
        flow_forward = cv2.calcOpticalFlowFarneback(prev_frame, next_frame, None, 0.5, 3, 15, 3, 5, 1.2, 0)
        # compute backward optical flow
        flow_backward = cv2.calcOpticalFlowFarneback(next_frame, prev_frame, None, 0.5, 3, 15, 3, 5, 1.2, 0)
        # combine the forward and backward flow maps
        bidirectional_flow[:, t] = torch.from_numpy((flow_forward + flow_backward).reshape(2, H, W))
    return bidirectional_flow


# Gaussian blur helper
def blur_video(video, kernel_size=(21, 21), sigma=21):
    """
    Apply Gaussian blur to every frame of a video.
    Args:
        video (torch.Tensor): input video of shape [T, C, H, W]
        kernel_size (tuple): blur kernel size
        sigma (float): standard deviation of the Gaussian kernel
    Returns:
        torch.Tensor: the blurred video
    """
    blurred_frames = []
    for frame in video:
        # convert to numpy format with shape (H, W, C)
        frame_np = frame.permute(1, 2, 0).numpy()
        # Gaussian blur via OpenCV
        blurred_frame = cv2.GaussianBlur(frame_np, kernel_size, sigma)
        # convert back to PyTorch layout (C, H, W)
        blurred_frame = torch.from_numpy(blurred_frame).permute(2, 0, 1)
        blurred_frames.append(blurred_frame)
    # stack the processed frames back into a video of shape [T, C, H, W]
    return torch.stack(blurred_frames)


def get_transforms_video(name="center", image_size=(256, 256)):
    if name is None:
        return None
    elif name == "center":
        assert image_size[0] == image_size[1], "image_size must be square for center crop"
        transform_video = transforms.Compose(
            [
                video_transforms.ToTensorVideo(),  # TCHW
                # video_transforms.RandomHorizontalFlipVideo(),
                video_transforms.UCFCenterCropVideo(image_size[0]),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    elif name == "resize_crop":
        transform_video = transforms.Compose(
            [
                video_transforms.ToTensorVideo(),  # TCHW
                video_transforms.ResizeCrop(image_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    elif name == "direct_crop":
        transform_video = transforms.Compose(
            [
                video_transforms.ToTensorVideo(),  # TCHW
                video_transforms.RandomCrop(image_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    else:
        raise NotImplementedError(f"Transform {name} not implemented")
    return transform_video


def get_transforms_image(name="center", image_size=(256, 256)):
    if name is None:
        return None
    elif name == "center":
        assert image_size[0] == image_size[1], "Image size must be square for center crop"
        transform = transforms.Compose(
            [
                transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0])),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    elif name == "resize_crop":
        transform = transforms.Compose(
            [
                transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    else:
        raise NotImplementedError(f"Transform {name} not implemented")
    return transform


def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
    image = pil_loader(path)
    if transform is None:
        transform = get_transforms_image(image_size=image_size, name=transform_name)
    image = transform(image)
    video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
    video = video.permute(1, 0, 2, 3)
    return video


def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
    vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
    if transform is None:
        transform = get_transforms_video(image_size=image_size, name=transform_name)
    video = transform(vframes)  # T C H W
    video = video.permute(1, 0, 2, 3)
    return video


def read_from_path(path, image_size, transform_name="center"):
    if is_url(path):
        path = download_url(path)
    ext = os.path.splitext(path)[-1].lower()
    if ext.lower() in VID_EXTENSIONS:
        return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
    else:
        assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
        return read_image_from_path(path, image_size=image_size, transform_name=transform_name)


def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1), force_video=False, align_method=None, validation_video=None):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    """
    Args:
        x (Tensor): shape [C, T, H, W]
    """
    assert x.ndim == 4

    if not force_video and x.shape[1] == 1:  # T = 1: save as image
        save_path += ".png"
        x = x.squeeze(1)
        save_image([x], save_path, normalize=normalize, value_range=value_range)
    else:
        save_path += ".mp4"
        if normalize:
            low, high = value_range
            x.clamp_(min=low, max=high)
            x.sub_(low).div_(max(high - low, 1e-5))

        if align_method:
            x = adain_color_fix(x, validation_video)

        x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
        write_video(save_path, x, fps=int(fps), video_codec="h264")
    # print(f"Saved to {save_path}")
    return save_path


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])


class StatefulDistributedSampler(DistributedSampler):
    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.start_index: int = 0

    def __iter__(self) -> Iterator:
        iterator = super().__iter__()
        indices = list(iterator)
        indices = indices[self.start_index :]
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples - self.start_index

    def set_start_index(self, start_index: int) -> None:
        self.start_index = start_index


def prepare_dataloader(
    dataset,
    batch_size,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: Optional[ProcessGroup] = None,
    **kwargs,
):
    r"""
    Prepare a dataloader for distributed training. The dataloader will be wrapped by
    `torch.utils.data.DataLoader` and `StatefulDistributedSampler`.

    Args:
        dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
        seed (int, optional): Random worker seed for sampling, defaults to 1024.
        add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
        drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
            is not divisible by the batch size. If False and the size of dataset is not divisible by
            the batch size, then the last batch will be smaller, defaults to False.
        pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
        num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
        kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
            `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.

    Returns:
        :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
    """
    _kwargs = kwargs.copy()
    process_group = process_group or _get_default_group()
    sampler = StatefulDistributedSampler(dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle)

    # Deterministic dataloader
    def seed_worker(worker_id):
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        worker_init_fn=seed_worker,
        drop_last=drop_last,
        pin_memory=pin_memory,
        num_workers=num_workers,
        **_kwargs,
    )
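A minimal sketch of the dataloader helper above under a single-process setup; it is not part of the committed file. The TensorDataset is a stand-in for the real video dataset, and the process-group settings (gloo backend, local TCP endpoint) are illustrative assumptions just to exercise StatefulDistributedSampler.

import torch
import torch.distributed as dist
from torch.utils.data import TensorDataset

# single-process "distributed" run so _get_default_group() is available (illustrative)
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

dataset = TensorDataset(torch.randn(64, 3, 256, 256))
loader = prepare_dataloader(dataset, batch_size=8, shuffle=True, num_workers=0)
for (batch,) in loader:
    print(batch.shape)  # torch.Size([8, 3, 256, 256])
    break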
utils_data/opensora/datasets/video_transforms.py (new file, mode 100644)

# Copyright 2024 Vchitect/Latte
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from Latte
# - This file is adapted from https://github.com/Vchitect/Latte/blob/main/datasets/video_transforms.py

import numbers
import random

import numpy as np  # needed by center_crop_arr below
import torch
from PIL import Image  # needed by center_crop_arr below


def _is_tensor_video_clip(clip):
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))
    if not clip.ndimension() == 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
    return True


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])


def crop(clip, i, j, h, w):
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
    """
    if len(clip.size()) != 4:
        raise ValueError("clip should be a 4D tensor")
    return clip[..., i : i + h, j : j + w]


def random_crop(clip, crop_size):
    """
    Args:
        clip (torch.Tensor): Video clip to be cropped. Size is (T, C, H, W)
        crop_size (tuple): Desired output size (h, w)
    Returns:
        torch.Tensor: Cropped video of size (T, C, h, w)
    """
    if len(clip.size()) != 4:
        raise ValueError("clip should be a 4D tensor")
    _, _, H, W = clip.shape
    th, tw = crop_size
    if th > H or tw > W:
        raise ValueError("Crop size should be smaller than video dimensions")

    i = torch.randint(0, H - th + 1, size=(1,)).item()
    j = torch.randint(0, W - tw + 1, size=(1,)).item()
    return crop(clip, i, j, th, tw)


def resize(clip, target_size, interpolation_mode):
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)


def resize_scale(clip, target_size, interpolation_mode):
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    H, W = clip.size(-2), clip.size(-1)
    scale_ = target_size[0] / min(H, W)
    return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)


def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing to the video clip
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    clip = crop(clip, i, j, h, w)
    clip = resize(clip, size, interpolation_mode)
    return clip


def center_crop(clip, crop_size):
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    if h < th or w < tw:
        clip = torch.nn.functional.interpolate(clip, size=(th, tw), mode="bilinear", align_corners=False)
        h, w = clip.size(-2), clip.size(-1)
        # raise ValueError("height and width must be no smaller than crop_size")

    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return crop(clip, i, j, th, tw)


def center_crop_using_short_edge(clip):
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    if h < w:
        th, tw = h, h
        i = 0
        j = int(round((w - tw) / 2.0))
    else:
        th, tw = w, w
        i = int(round((h - th) / 2.0))
        j = 0
    return crop(clip, i, j, th, tw)


def random_shift_crop(clip):
    """
    Slide along the long edge, with the short edge as crop size
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)

    if h <= w:
        short_edge = h
    else:
        short_edge = w

    th, tw = short_edge, short_edge

    i = torch.randint(0, h - th + 1, size=(1,)).item()
    j = torch.randint(0, w - tw + 1, size=(1,)).item()
    return crop(clip, i, j, th, tw)


def to_tensor(clip):
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
    """
    _is_tensor_video_clip(clip)
    if not clip.dtype == torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    # return clip.float().permute(3, 0, 1, 2) / 255.0
    return clip.float() / 255.0


def normalize(clip, mean, std, inplace=False):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
    Returns:
        normalized clip (torch.tensor): Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    if not inplace:
        clip = clip.clone()
    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    # print(mean)
    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return clip


def hflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    return clip.flip(-1)


class RandomCropVideo:
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (T, C, OH, OW)
        """
        i, j, h, w = self.get_params(clip)
        return crop(clip, i, j, h, w)

    def get_params(self, clip):
        h, w = clip.shape[-2:]
        th, tw = self.size

        if h < th or w < tw:
            raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")

        if w == tw and h == th:
            return 0, 0, h, w

        i = torch.randint(0, h - th + 1, size=(1,)).item()
        j = torch.randint(0, w - tw + 1, size=(1,)).item()

        return i, j, th, tw

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"


class CenterCropResizeVideo:
    """
    First use the short side for cropping length,
    center crop video, then resize to the specified size
    """

    def __init__(self, size, interpolation_mode="bilinear"):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_center_crop = center_crop_using_short_edge(clip)
        clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
        return clip_center_crop_resize

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"


class UCFCenterCropVideo:
    """
    First scale to the specified size in equal proportion to the short edge,
    then center cropping
    """

    def __init__(self, size, interpolation_mode="bilinear"):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        clip_center_crop = center_crop(clip_resize, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"


class KineticsRandomCropResizeVideo:
    """
    Slide along the long edge, with the short edge as crop size. And resize to the desired size.
    """

    def __init__(self, size, interpolation_mode="bilinear"):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        clip_random_crop = random_shift_crop(clip)
        clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
        return clip_resize


class CenterCropVideo:
    def __init__(self, size, interpolation_mode="bilinear"):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_center_crop = center_crop(clip, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"


class NormalizeVideo:
    """
    Normalize the video clip by mean subtraction and division by standard deviation
    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
        """
        return normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"


class ToTensorVideo:
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        return to_tensor(clip)

    def __repr__(self) -> str:
        return self.__class__.__name__


class RandomCrop:
    """
    Perform random cropping on a video tensor of shape (T, C, H, W).
    """

    def __init__(self, crop_size):
        """
        Args:
            crop_size (tuple): Desired output size (h, w)
        """
        self.crop_size = crop_size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Video tensor of size (T, C, H, W)
        Returns:
            torch.tensor: Cropped video tensor of size (T, C, h, w), dtype=torch.float
        """
        return random_crop(clip, self.crop_size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(crop_size={self.crop_size})"


class RandomHorizontalFlipVideo:
    """
    Flip the video clip along the horizontal direction with a given probability
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (T, C, H, W)
        Return:
            clip (torch.tensor): Size is (T, C, H, W)
        """
        if random.random() < self.p:
            clip = hflip(clip)
        return clip

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


# ------------------------------------------------------------
# ---------------------  Sampling  ---------------------------
# ------------------------------------------------------------
class TemporalRandomCrop(object):
    """Temporally crop the given frame indices at a random location.
    Args:
        size (int): Desired length of frames will be seen in the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        rand_end = max(0, total_frames - self.size - 1)
        begin_index = random.randint(0, rand_end)
        end_index = min(begin_index + self.size, total_frames)
        return begin_index, end_index


if __name__ == "__main__":
    import os

    import numpy as np
    import torchvision.io as io
    from torchvision import transforms
    from torchvision.utils import save_image

    vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW")

    trans = transforms.Compose(
        [
            ToTensorVideo(),
            RandomHorizontalFlipVideo(),
            UCFCenterCropVideo(512),
            # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )

    target_video_len = 32
    frame_interval = 1
    total_frames = len(vframes)
    print(total_frames)

    temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)

    # Sampling video frames
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    # print(start_frame_ind)
    # print(end_frame_ind)
    assert end_frame_ind - start_frame_ind >= target_video_len

    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
    print(frame_indice)

    select_vframes = vframes[frame_indice]
    print(select_vframes.shape)
    print(select_vframes.dtype)

    select_vframes_trans = trans(select_vframes)
    print(select_vframes_trans.shape)
    print(select_vframes_trans.dtype)

    select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
    print(select_vframes_trans_int.dtype)
    print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)

    io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)

    for i in range(target_video_len):
        save_image(
            select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, value_range=(-1, 1)
        )


class ResizeCrop:
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, clip):
        clip = resize_crop_to_fill(clip, self.size)
        return clip

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"


def _is_tensor_video_clip(clip):
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))
    if not clip.ndimension() == 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
    return True


def resize(clip, target_size, interpolation_mode):
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)


def crop(clip, i, j, h, w):
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
    """
    if len(clip.size()) != 4:
        raise ValueError("clip should be a 4D tensor")
    return clip[..., i : i + h, j : j + w]


def resize_crop_to_fill(clip, target_size):
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    th, tw = target_size[0], target_size[1]
    rh, rw = th / h, tw / w
    if rh > rw:
        sh, sw = th, round(w * rh)
        clip = resize(clip, (sh, sw), "bilinear")
        i = 0
        j = int(round(sw - tw) / 2.0)
    else:
        sh, sw = round(h * rw), tw
        clip = resize(clip, (sh, sw), "bilinear")
        i = int(round(sh - th) / 2.0)
        j = 0
    assert i + th <= clip.size(-2) and j + tw <= clip.size(-1)
    return crop(clip, i, j, th, tw)
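A short sketch of composing the transforms defined above, mirroring the "resize_crop" branch of get_transforms_video in utils.py; it is not part of the committed file, and the random uint8 tensor stands in for decoded video frames.

import torch
import torchvision.transforms as transforms
# assumes ToTensorVideo and ResizeCrop are imported from this module

frames = torch.randint(0, 256, (16, 3, 360, 640), dtype=torch.uint8)  # (T, C, H, W) uint8
pipeline = transforms.Compose(
    [
        ToTensorVideo(),   # uint8 -> float in [0, 1]
        ResizeCrop((256, 256)),  # resize so the target fits, then center crop
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
)
clip = pipeline(frames)
print(clip.shape, clip.dtype)  # torch.Size([16, 3, 256, 256]) torch.float32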
utils_data/opensora/datasets/wavelet_color_fix.py (new file, mode 100644)

'''
# --------------------------------------------------------------------------------
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
# --------------------------------------------------------------------------------
'''
import torch
from PIL import Image
from torch import Tensor
from torch.nn import functional as F
from einops import rearrange

from torchvision.transforms import ToTensor, ToPILImage


def adain_color_fix(target: Image, source: Image):
    # torch.Size([3, 16, 256, 256])
    # Apply adaptive instance normalization
    target = rearrange(target, "C T H W -> T C H W")
    source = rearrange(source, "C T H W -> T C H W")
    result_tensor = adaptive_instance_normalization(target, source)
    result_tensor = rearrange(result_tensor, "T C H W -> C T H W").clamp_(0.0, 1.0)

    # Convert tensor back to image
    # to_image = ToPILImage()
    # result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_tensor


def wavelet_color_fix(target: Image, source: Image):
    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = to_tensor(target).unsqueeze(0)
    source_tensor = to_tensor(source).unsqueeze(0)

    # Apply wavelet reconstruction
    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image


def calc_mean_std(feat: Tensor, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.
    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
    feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
    """Adaptive instance normalization.
    Adjust the reference features to have the similar color and illuminations
    as those in the degradate features.
    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degradate features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)


def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.
    """
    # input shape: (1, 3, H, W)
    # convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add channel dimensions to the kernel to make it a 4D tensor
    kernel = kernel[None, None]
    # repeat the kernel across all input channels
    kernel = kernel.repeat(3, 1, 1, 1)

    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # apply convolution
    output = F.conv2d(image, kernel, groups=3, dilation=radius)
    return output


def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the low frequency & the high frequency.
    """
    high_freq = torch.zeros_like(image)
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur(image, radius)
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq


def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """
    Apply wavelet decomposition, so that the content will have the same color as the style.
    """
    # calculate the wavelet decomposition of the content feature
    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
    del content_low_freq
    # calculate the wavelet decomposition of the style feature
    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
    del style_high_freq
    # reconstruct the content's high frequency with the style's low frequency
    return content_high_freq + style_low_freq
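A minimal sketch of the color-fix entry point that save_sample in utils.py calls; it is not part of the committed file. The two random tensors stand in for a generated clip and its reference (validation) clip, both in C T H W layout and [0, 1] range, as the comment inside adain_color_fix indicates.

import torch
# assumes adain_color_fix is imported from this module

generated = torch.rand(3, 16, 256, 256)  # (C, T, H, W) generated clip
reference = torch.rand(3, 16, 256, 256)  # (C, T, H, W) reference clip
fixed = adain_color_fix(generated, reference)
print(fixed.shape, float(fixed.min()), float(fixed.max()))  # same shape, clamped to [0, 1]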
utils_data/opensora/models/__init__.py (new file, mode 100644)

from .dit import *
from .latte import *
from .pixart import *
from .stdit import *
from .text_encoder import *
from .vae import *
utils_data/opensora/models/__pycache__/__init__.cpython-39.pyc (new binary file added, mode 100644)
utils_data/opensora/models/dit/__init__.py (new file, mode 100644)

from .dit import DiT, DiT_XL_2, DiT_XL_2x2
utils_data/opensora/models/dit/__pycache__/__init__.cpython-39.pyc (new binary file added, mode 100644)
utils_data/opensora/models/dit/__pycache__/dit.cpython-39.pyc (new binary file added, mode 100644)
utils_data/opensora/models/dit/dit.py (new file, mode 100644)

# Modified from Meta DiT

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from einops import rearrange
from timm.models.vision_transformer import Mlp

from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.layers.blocks import (
    Attention,
    CaptionEmbedder,
    FinalLayer,
    LabelEmbedder,
    PatchEmbed3D,
    TimestepEmbedder,
    approx_gelu,
    get_1d_sincos_pos_embed,
    get_2d_sincos_pos_embed,
    get_layernorm,
    modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint


class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.enable_flashattn = enable_flashattn
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=enable_flashattn,
        )
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1, x, shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2, x, shift_mlp, scale_mlp))
        return x


@MODELS.register_module()
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        input_size=(16, 32, 32),
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        learn_sigma=True,
        condition="text",
        no_temporal_pos_emb=False,
        caption_channels=512,
        model_max_length=77,
        dtype=torch.float32,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.input_size = input_size
        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
        self.num_patches = num_patches
        self.num_temporal = input_size[0] // patch_size[0]
        self.num_spatial = num_patches // self.num_temporal
        self.num_heads = num_heads
        self.dtype = dtype
        self.use_text_encoder = not condition.startswith("label")
        if enable_flashattn:
            assert dtype in [
                torch.float16,
                torch.bfloat16,
            ], f"Flash attention only supports float16 and bfloat16, but got {self.dtype}"
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.mlp_ratio = mlp_ratio
        self.depth = depth
        assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in DiT"

        self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed())
        self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())

        self.x_embedder = PatchEmbed3D(patch_size, in_channels, embed_dim=hidden_size)
        if not self.use_text_encoder:
            num_classes = int(condition.split("_")[-1])
            self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        else:
            self.y_embedder = CaptionEmbedder(
                in_channels=caption_channels,
                hidden_size=hidden_size,
                uncond_prob=class_dropout_prob,
                act_layer=approx_gelu,
                token_num=1,  # pooled token
            )
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    enable_flashattn=enable_flashattn,
                    enable_layernorm_kernel=enable_layernorm_kernel,
                )
                for _ in range(depth)
            ]
        )
        self.final_layer = FinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)

        self.initialize_weights()
        self.enable_flashattn = enable_flashattn
        self.enable_layernorm_kernel = enable_layernorm_kernel

    def get_spatial_pos_embed(self):
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[1] // self.patch_size[1],
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[0] // self.patch_size[0],
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def unpatchify(self, x):
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size

        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def forward(self, x, t, y):
        """
        Forward pass of DiT.
        x: (B, C, T, H, W) tensor of inputs
        t: (B,) tensor of diffusion timesteps
        y: list of text
        """
        # origin inputs should be float32, cast to specified dtype
        x = x.to(self.dtype)
        if self.use_text_encoder:
            y = y.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # (B, N, D)
        x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
        x = x + self.pos_embed_spatial
        if not self.no_temporal_pos_emb:
            x = rearrange(x, "b t s d -> b s t d")
            x = x + self.pos_embed_temporal
            x = rearrange(x, "b s t d -> b (t s) d")
        else:
            x = rearrange(x, "b t s d -> b (t s) d")

        t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        y = self.y_embedder(y, self.training)  # (N, D)
        if self.use_text_encoder:
            y = y.squeeze(1).squeeze(1)
        condition = t + y

        # blocks
        for _, block in enumerate(self.blocks):
            c = condition
            x = auto_grad_checkpoint(block, x, c)  # (B, N, D)

        # final process
        x = self.final_layer(x, condition)  # (B, N, num_patches * out_channels)
        x = self.unpatchify(x)  # (B, out_channels, T, H, W)

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                if module.weight.requires_grad_:
                    torch.nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

        # Zero-out text embedding layers:
        if self.use_text_encoder:
            nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
            nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)


@MODELS.register_module("DiT-XL/2")
def DiT_XL_2(from_pretrained=None, **kwargs):
    model = DiT(
        depth=28,
        hidden_size=1152,
        patch_size=(1, 2, 2),
        num_heads=16,
        **kwargs,
    )
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model


@MODELS.register_module("DiT-XL/2x2")
def DiT_XL_2x2(from_pretrained=None, **kwargs):
    model = DiT(
        depth=28,
        hidden_size=1152,
        patch_size=(2, 2, 2),
        num_heads=16,
        **kwargs,
    )
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model
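A rough sketch of how the registered DiT-XL/2 variant above could be instantiated and called, assuming the full opensora package (including opensora.models.layers.blocks, which is only partially shown in this commit) is importable; it is not part of the committed file. The latent shape follows the default input_size=(16, 32, 32) with in_channels=4, and the caption tensor shape is an assumption about the pooled text-encoder output expected by CaptionEmbedder with token_num=1.

import torch
# from opensora.models.dit import DiT_XL_2  # assumes the opensora package is on PYTHONPATH

model = DiT_XL_2(condition="text", caption_channels=512)
x = torch.randn(2, 4, 16, 32, 32)   # (B, C, T, H, W) latent video
t = torch.randint(0, 1000, (2,))    # diffusion timesteps
y = torch.randn(2, 1, 1, 512)       # pooled caption embedding, one token per sample (assumed shape)
out = model(x, t, y)
print(out.shape)                    # (B, 8, 16, 32, 32) since learn_sigma=True doubles the channels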
utils_data/opensora/models/latte/__init__.py (new file, mode 100644)

from .latte import Latte, Latte_XL_2, Latte_XL_2x2
utils_data/opensora/models/latte/__pycache__/__init__.cpython-39.pyc (new binary file added, mode 100644)
utils_data/opensora/models/latte/__pycache__/latte.cpython-39.pyc (new binary file added, mode 100644)
utils_data/opensora/models/latte/latte.py
0 → 100644
View file @
1f5da520
# Copyright 2024 Vchitect/Latte
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from Latte.
# This file is modified from https://github.com/Vchitect/Latte/blob/main/models/latte.py
#
# With references to:
#   Latte: https://github.com/Vchitect/Latte
#   DiT:   https://github.com/facebookresearch/DiT/tree/main

import torch
from einops import rearrange, repeat

from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.dit import DiT
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint


@MODELS.register_module()
class Latte(DiT):
    def forward(self, x, t, y):
        """
        Forward pass of DiT.
        x: (B, C, T, H, W) tensor of inputs
        t: (B,) tensor of diffusion timesteps
        y: list of text
        """
        # original inputs should be float32, cast to the specified dtype
        x = x.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # (B, N, D)
        x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
        x = x + self.pos_embed_spatial
        x = rearrange(x, "b t s d -> b (t s) d")

        t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        y = self.y_embedder(y, self.training)  # (N, D)
        if self.use_text_encoder:
            y = y.squeeze(1).squeeze(1)
        condition = t + y
        condition_spatial = repeat(condition, "b d -> (b t) d", t=self.num_temporal)
        condition_temporal = repeat(condition, "b d -> (b s) d", s=self.num_spatial)

        # blocks
        for i, block in enumerate(self.blocks):
            if i % 2 == 0:
                # spatial
                x = rearrange(x, "b (t s) d -> (b t) s d", t=self.num_temporal, s=self.num_spatial)
                c = condition_spatial
            else:
                # temporal
                x = rearrange(x, "b (t s) d -> (b s) t d", t=self.num_temporal, s=self.num_spatial)
                c = condition_temporal
                if i == 1:
                    x = x + self.pos_embed_temporal

            x = auto_grad_checkpoint(block, x, c)  # (B, N, D)

            if i % 2 == 0:
                x = rearrange(x, "(b t) s d -> b (t s) d", t=self.num_temporal, s=self.num_spatial)
            else:
                x = rearrange(x, "(b s) t d -> b (t s) d", t=self.num_temporal, s=self.num_spatial)

        # final process
        x = self.final_layer(x, condition)  # (B, N, num_patches * out_channels)
        x = self.unpatchify(x)  # (B, out_channels, T, H, W)

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x


@MODELS.register_module("Latte-XL/2")
def Latte_XL_2(from_pretrained=None, **kwargs):
    model = Latte(
        depth=28,
        hidden_size=1152,
        patch_size=(1, 2, 2),
        num_heads=16,
        **kwargs,
    )
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model


@MODELS.register_module("Latte-XL/2x2")
def Latte_XL_2x2(from_pretrained=None, **kwargs):
    model = Latte(
        depth=28,
        hidden_size=1152,
        patch_size=(2, 2, 2),
        num_heads=16,
        **kwargs,
    )
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model
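
# --- illustration (not part of the original file) ---
# A minimal sketch of the token regrouping used by Latte.forward above: even-indexed
# blocks attend over spatial tokens within each frame, odd-indexed blocks attend over
# temporal tokens at each spatial location. The sizes below are toy values chosen
# purely for illustration (B=2, T=4 frames, S=16 spatial tokens, D=8 channels).
import torch
from einops import rearrange

B, T, S, D = 2, 4, 16, 8
tokens = torch.randn(B, T * S, D)

spatial = rearrange(tokens, "b (t s) d -> (b t) s d", t=T, s=S)   # (B*T, S, D): per-frame spatial attention
temporal = rearrange(tokens, "b (t s) d -> (b s) t d", t=T, s=S)  # (B*S, T, D): per-location temporal attention

# regrouping back is lossless
restored = rearrange(spatial, "(b t) s d -> b (t s) d", t=T, s=S)
assert torch.equal(restored, tokens)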
utils_data/opensora/models/layers/__init__.py
utils_data/opensora/models/layers/__pycache__/__init__.cpython-39.pyc
File added
utils_data/opensora/models/layers/__pycache__/blocks.cpython-39.pyc
File added
utils_data/opensora/models/layers/__pycache__/timm_uvit.cpython-39.pyc
File added
utils_data/opensora/models/layers/blocks.py
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
# Latte: https://github.com/Vchitect/Latte
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import math
from typing import Any, Dict, List, Optional, Tuple, Union, KeysView

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import xformers.ops
from einops import rearrange
from timm.models.vision_transformer import Mlp

from opensora.acceleration.communications import all_to_all, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group

# import ipdb

approx_gelu = lambda: nn.GELU(approximate="tanh")
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # ipdb.set_trace()
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # ipdb.set_trace()
        return self.weight * hidden_states.to(input_dtype)
def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool, use_kernel: bool):
    if use_kernel:
        try:
            from apex.normalization import FusedLayerNorm

            return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps)
        except ImportError:
            raise RuntimeError("FusedLayerNorm not available. Please install apex.")
    else:
        return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine)


def modulate(norm_func, x, shift, scale):
    # Suppose x is (B, N, D), shift is (B, D), scale is (B, D)
    dtype = x.dtype
    x = norm_func(x.to(torch.float32)).to(dtype)
    x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
    x = x.to(dtype)
    return x


def t2i_modulate(x, shift, scale):
    return x * (1 + scale) + shift
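
# --- illustration (not part of the original file) ---
# A minimal sketch of the adaLN-style modulation above: per-sample shift/scale
# vectors of shape (B, D) are broadcast over the token dimension N of a
# (B, N, D) activation. Sizes here are arbitrary toy values.
import torch
import torch.nn as nn

B, N, D = 2, 5, 8
x = torch.randn(B, N, D)
shift, scale = torch.randn(B, D), torch.randn(B, D)
norm = nn.LayerNorm(D, elementwise_affine=False, eps=1e-6)

out = modulate(norm, x, shift, scale)
# equivalent explicit broadcast:
ref = norm(x.float()) * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
assert torch.allclose(out, ref.to(out.dtype), atol=1e-6)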
# ===============================================
# General-purpose Layers
# ===============================================
class PatchEmbed3D(nn.Module):
    """Video to Patch Embedding.

    Args:
        patch_size (int): Patch token size. Default: (2,4,4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self,
        patch_size=(2, 4, 4),
        in_chans=3,
        embed_dim=96,
        padding=None,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.padding = padding

        if padding is not None:
            self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, padding=padding)
        else:
            self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        if self.padding is None:
            # padding
            _, _, D, H, W = x.size()
            if W % self.patch_size[2] != 0:
                x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
            if H % self.patch_size[1] != 0:
                x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
            if D % self.patch_size[0] != 0:
                x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))

        x = self.proj(x)  # (B C T H W)
        if self.norm is not None:
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
        return x
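
# --- illustration (not part of the original file) ---
# A minimal sketch of what PatchEmbed3D does to a toy video tensor: a
# (B, C, T, H, W) clip becomes (B, N, D) patch tokens, where
# N = (T/pt) * (H/ph) * (W/pw) for patch_size (pt, ph, pw). Toy sizes only.
import torch

embedder = PatchEmbed3D(patch_size=(1, 2, 2), in_chans=4, embed_dim=96)
video = torch.randn(2, 4, 8, 32, 32)   # (B, C, T, H, W)
tokens = embedder(video)
print(tokens.shape)                    # torch.Size([2, 2048, 96]) = (2, 8*16*16, 96)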
class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        enable_flashattn: bool = False,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.enable_flashattn = enable_flashattn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x)
        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        if self.enable_flashattn:  # here
            qkv_permute_shape = (2, 0, 1, 3, 4)
        else:
            qkv_permute_shape = (2, 0, 3, 1, 4)
        qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        if self.enable_flashattn:
            from flash_attn import flash_attn_func

            x = flash_attn_func(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                softmax_scale=self.scale,
            )
        else:
            dtype = q.dtype
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            # translate attn to float32
            attn = attn.to(torch.float32)
            attn = attn.softmax(dim=-1)
            attn = attn.to(dtype)  # cast back attn to original dtype
            attn = self.attn_drop(attn)
            x = attn @ v

        x_output_shape = (B, N, C)
        if not self.enable_flashattn:
            x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Attention_QKNorm_RoPE(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = LlamaRMSNorm,
        enable_flashattn: bool = False,
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.enable_flashattn = enable_flashattn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rotary_emb = rope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x)
        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        if self.enable_flashattn:
            qkv_permute_shape = (2, 0, 1, 3, 4)
        else:
            qkv_permute_shape = (2, 0, 3, 1, 4)
        qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
        q, k, v = qkv.unbind(0)
        # ipdb.set_trace()
        if self.rotary_emb is not None:
            q = self.rotary_emb(q)
            k = self.rotary_emb(k)
        # ipdb.set_trace()
        q, k = self.q_norm(q), self.k_norm(k)
        # ipdb.set_trace()
        if self.enable_flashattn:
            from flash_attn import flash_attn_func

            x = flash_attn_func(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                softmax_scale=self.scale,
            )
        else:
            dtype = q.dtype
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            # translate attn to float32
            attn = attn.to(torch.float32)
            attn = attn.softmax(dim=-1)
            attn = attn.to(dtype)  # cast back attn to original dtype
            attn = self.attn_drop(attn)
            x = attn @ v

        x_output_shape = (B, N, C)
        if not self.enable_flashattn:
            x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MaskedSelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = LlamaRMSNorm,
        enable_flashattn: bool = False,
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.enable_flashattn = enable_flashattn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rotary_emb = rope

    def forward(self, x, mask):
        B, N, C = x.shape
        qkv = self.qkv(x)
        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv_permute_shape = (2, 0, 3, 1, 4)
        qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
        q, k, v = qkv.unbind(0)  # B H N C
        # ipdb.set_trace()
        if self.rotary_emb is not None:
            q = self.rotary_emb(q)
            k = self.rotary_emb(k)
        # ipdb.set_trace()
        q, k = self.q_norm(q), self.k_norm(k)
        # ipdb.set_trace()
        mask = mask.unsqueeze(1).unsqueeze(1).repeat(1, self.num_heads, 1, 1).to(torch.float32)  # B H 1 N

        dtype = q.dtype
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        # translate attn to float32
        attn = attn.to(torch.float32)
        attn = attn.masked_fill(mask == 0, -1e9)
        attn = attn.softmax(dim=-1)
        attn = attn.to(dtype)  # cast back attn to original dtype
        attn = self.attn_drop(attn)
        x = attn @ v

        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SeqParallelAttention(Attention):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        enable_flashattn: bool = False,
    ) -> None:
        super().__init__(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            enable_flashattn=enable_flashattn,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape  # for sequence parallel here, the N is a local sequence length
        qkv = self.qkv(x)
        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.view(qkv_shape)

        sp_group = get_sequence_parallel_group()

        # apply all_to_all to gather sequence and split attention heads
        # [B, SUB_N, 3, NUM_HEAD, HEAD_DIM] -> [B, N, 3, NUM_HEAD_PER_DEVICE, HEAD_DIM]
        qkv = all_to_all(qkv, sp_group, scatter_dim=3, gather_dim=1)

        if self.enable_flashattn:
            qkv_permute_shape = (2, 0, 1, 3, 4)  # [3, B, N, NUM_HEAD_PER_DEVICE, HEAD_DIM]
        else:
            qkv_permute_shape = (2, 0, 3, 1, 4)  # [3, B, NUM_HEAD_PER_DEVICE, N, HEAD_DIM]
        qkv = qkv.permute(qkv_permute_shape)

        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        if self.enable_flashattn:
            from flash_attn import flash_attn_func

            x = flash_attn_func(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                softmax_scale=self.scale,
            )
        else:
            dtype = q.dtype
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            # translate attn to float32
            attn = attn.to(torch.float32)
            attn = attn.softmax(dim=-1)
            attn = attn.to(dtype)  # cast back attn to original dtype
            attn = self.attn_drop(attn)
            x = attn @ v

        if not self.enable_flashattn:
            x = x.transpose(1, 2)

        # apply all to all to gather back attention heads and split sequence
        # [B, N, NUM_HEAD_PER_DEVICE, HEAD_DIM] -> [B, SUB_N, NUM_HEAD, HEAD_DIM]
        x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)

        # reshape outputs back to [B, N, C]
        x_output_shape = (B, N, C)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(MultiHeadCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        B, N, C = x.shape

        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)
        # ipdb.set_trace()
        attn_bias = None
        if mask is not None:
            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
        # ipdb.set_trace()

        x = x.view(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
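
# --- illustration (not part of the original file) ---
# A minimal sketch of the packing trick used by MultiHeadCrossAttention above:
# the batch is flattened into one long sequence and a block-diagonal mask keeps
# each sample's image tokens attending only to that sample's (unpadded) text
# tokens. The real code builds this with
# xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask); below is an
# equivalent dense boolean mask in plain torch, with toy sizes, for illustration only.
import torch

B, N = 2, 3                  # two samples, three image tokens each
text_lens = [4, 2]           # valid text tokens per sample (the `mask` list)

q_lens = [N] * B
allowed = torch.zeros(sum(q_lens), sum(text_lens), dtype=torch.bool)
q_start, k_start = 0, 0
for q_len, k_len in zip(q_lens, text_lens):
    allowed[q_start:q_start + q_len, k_start:k_start + k_len] = True
    q_start += q_len
    k_start += k_len
print(allowed.int())
# Queries 0-2 (sample 0) see keys 0-3; queries 3-5 (sample 1) see keys 4-5.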
class MaskedMultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(MaskedMultiHeadCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        B, S, C = x.shape
        L = cond.shape[1]

        q = self.q_linear(x).view(B, S, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(B, L, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)
        # ipdb.set_trace()
        attn_bias = None
        if mask is not None:
            attn_bias = mask.unsqueeze(1).unsqueeze(1).repeat(1, self.num_heads, S, 1).to(q.dtype)  # B H S L
            exp = -1e9
            attn_bias[attn_bias == 0] = exp
            attn_bias[attn_bias == 1] = 0
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
        # ipdb.set_trace()

        x = x.view(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MaskedMeanMultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(MaskedMeanMultiHeadCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        B, T, S, C = x.shape
        L = cond.shape[2]
        x = rearrange(x, "B T S C -> B (T S) C")
        N = x.shape[1]
        cond = torch.mean(cond, dim=1)  # B L C
        mask = mask[:, 0, :]  # B L

        q = self.q_linear(x).view(B, N, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(B, L, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)
        # ipdb.set_trace()
        attn_bias = None
        if mask is not None:
            attn_bias = mask.unsqueeze(1).unsqueeze(1).repeat(1, self.num_heads, N, 1).to(q.dtype)  # B H N L
            exp = -1e9
            attn_bias[attn_bias == 0] = exp
            attn_bias[attn_bias == 1] = 0
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
        # ipdb.set_trace()

        x = rearrange(x, "B (T S) H C -> (B T) S (H C)", T=T, S=S)
        x = self.proj(x)
        x = self.proj_drop(x)
        x = rearrange(x, "(B T) S C -> B T S C", B=B, T=T)
        return x


class LongShortMultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(LongShortMultiHeadCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        B, N, C = x.shape
        M = cond.shape[1]

        q = self.q_linear(x).view(B, N, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(B, M, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)
        attn_bias = None
        if mask is not None:
            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)

        x = x.view(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MultiHeadV2TCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(MultiHeadV2TCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: condition; key: img tokens; mask: if padding tokens
        B, N, C = cond.shape

        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)
        # ipdb.set_trace()
        attn_bias = None
        if mask is not None:
            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(mask, [N] * B)
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
        # ipdb.set_trace()

        x = x.view(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MultiHeadT2VCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(MultiHeadT2VCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        # ipdb.set_trace()
        B, T, N, C = x.shape
        x = rearrange(x, 'B T N C -> (B T) N C')

        q = self.q_linear(x)
        q = rearrange(q, '(B T) N C -> B T N C', T=T)
        q = q.view(1, -1, self.num_heads, self.head_dim)  # 1 (B T N) H C
        kv = self.kv_linear(cond)
        kv = kv.view(1, -1, 2, self.num_heads, self.head_dim)  # 1 N 2 H C
        k, v = kv.unbind(2)
        # ipdb.set_trace()
        attn_bias = None
        if mask is not None:
            # mask = [m for m in mask for _ in range(T)]
            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * (B * T), mask)
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
        # ipdb.set_trace()

        x = x.view(B, T, N, C)
        x = rearrange(x, 'B T N C -> (B T) N C')
        x = self.proj(x)
        x = self.proj_drop(x)
        x = rearrange(x, '(B T) N C -> B T N C', T=T)
        return x


class FormerMultiHeadV2TCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(FormerMultiHeadV2TCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # x: text tokens; cond: img tokens; mask: if padding tokens
        # ipdb.set_trace()
        _, N, C = x.shape  # 1 N C
        B, T, _, _ = cond.shape
        cond = rearrange(cond, 'B T N C -> (B T) N C')

        q = self.q_linear(x)
        q = q.view(1, -1, self.num_heads, self.head_dim)  # 1 N H C
        kv = self.kv_linear(cond)
        kv = rearrange(kv, '(B T) N C -> B T N C', B=B)
        M = kv.shape[2]  # M = H * W
        former_frame_index = torch.arange(T) - 1
        former_frame_index[0] = 0
        # ipdb.set_trace()
        former_kv = kv[:, former_frame_index]
        former_kv = former_kv.view(1, -1, 2, self.num_heads, self.head_dim)  # 1 (B T N) 2 H C
        former_k, former_v = former_kv.unbind(2)
        # ipdb.set_trace()
        attn_bias = None
        if mask is not None:
            # mask = [m for m in mask for _ in range(T)]
            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(mask, [M] * (B * T))
        x = xformers.ops.memory_efficient_attention(q, former_k, former_v, p=self.attn_drop.p, attn_bias=attn_bias)
        # ipdb.set_trace()

        x = x.view(1, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LatterMultiHeadV2TCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(LatterMultiHeadV2TCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # x: text tokens; cond: img tokens; mask: if padding tokens
        # ipdb.set_trace()
        _, N, C = x.shape  # 1 N C
        B, T, _, _ = cond.shape
        cond = rearrange(cond, 'B T N C -> (B T) N C')

        q = self.q_linear(x)
        q = q.view(1, -1, self.num_heads, self.head_dim)  # 1 N H C
        kv = self.kv_linear(cond)
        kv = rearrange(kv, '(B T) N C -> B T N C', T=T)
        M = kv.shape[2]  # M = H * W
        latter_frame_index = torch.arange(T) + 1
        latter_frame_index[-1] = T - 1
        # ipdb.set_trace()
        latter_kv = kv[:, latter_frame_index]
        latter_kv = latter_kv.view(1, -1, 2, self.num_heads, self.head_dim)  # 1 (B T N) 2 H C
        latter_k, latter_v = latter_kv.unbind(2)
        # ipdb.set_trace()
        attn_bias = None
        if mask is not None:
            # mask = [m for m in mask for _ in range(T)]
            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(mask, [M] * (B * T))
        x = xformers.ops.memory_efficient_attention(q, latter_k, latter_v, p=self.attn_drop.p, attn_bias=attn_bias)
        # ipdb.set_trace()

        x = x.view(1, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
    def __init__(
        self,
        d_model,
        num_heads,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__(d_model=d_model, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        sp_group = get_sequence_parallel_group()
        sp_size = dist.get_world_size(sp_group)
        B, SUB_N, C = x.shape
        N = SUB_N * sp_size

        # shape:
        # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
        q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(B, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)

        # apply all_to_all to gather sequence and split attention heads
        q = all_to_all(q, sp_group, scatter_dim=2, gather_dim=1)

        k = split_forward_gather_backward(k, get_sequence_parallel_group(), dim=2, grad_scale="down")
        v = split_forward_gather_backward(v, get_sequence_parallel_group(), dim=2, grad_scale="down")

        q = q.view(1, -1, self.num_heads // sp_size, self.head_dim)
        k = k.view(1, -1, self.num_heads // sp_size, self.head_dim)
        v = v.view(1, -1, self.num_heads // sp_size, self.head_dim)

        # compute attention
        attn_bias = None
        if mask is not None:
            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)

        # apply all to all to gather back attention heads and scatter sequence
        x = x.view(B, -1, self.num_heads // sp_size, self.head_dim)
        x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)

        # apply output projection
        x = x.view(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(self, hidden_size, num_patch, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final, x, shift, scale)
        x = self.linear(x)
        return x


class T2IFinalLayer(nn.Module):
    """
    The final layer of PixArt.
    """

    def __init__(self, hidden_size, num_patch, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
        self.out_channels = out_channels

    def forward(self, x, t):
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x
# ==================
# Frequency Layers
# ==================
class SpatialFrequencyBlcok(nn.Module):
    def __init__(self, dim):
        super(SpatialFrequencyBlcok, self).__init__()
        self.act_layer = nn.GELU(approximate="tanh")
        # Process low-frequency
        self.low_freq_layer1 = nn.Linear(in_features=dim, out_features=2 * dim)
        self.low_freq_layer2 = nn.Linear(in_features=2 * dim, out_features=dim)
        # Process high-frequency
        self.high_freq_layer1 = nn.Linear(in_features=dim, out_features=2 * dim)
        self.high_freq_layer2 = nn.Linear(in_features=2 * dim, out_features=dim)

    def forward(self, x, use_cfg=True):
        if use_cfg:
            # x shape: torch.Size([4, 4096, 1152])
            high_1, low_1, high_2, low_2 = torch.chunk(x, 4, dim=0)
            highfreq = torch.cat((high_1, high_2), dim=0)  # torch.Size([2, 4096, 1152])
            lowfreq = torch.cat((low_1, low_2), dim=0)  # torch.Size([2, 4096, 1152])
            # extension
            highfreq, hf_info = self.high_freq_layer1(highfreq).chunk(2, dim=-1)
            lowfreq, lf_info = self.low_freq_layer1(lowfreq).chunk(2, dim=-1)
            # fusion
            high_1, high_2 = self.high_freq_layer2(torch.cat((highfreq, lf_info), dim=-1)).chunk(2, dim=0)
            low_1, low_2 = self.low_freq_layer2(torch.cat((lowfreq, hf_info), dim=-1)).chunk(2, dim=0)
            out = torch.cat((high_1, low_1, high_2, low_2), dim=0)
        else:
            highfreq, lowfreq = torch.chunk(x, 2, dim=0)
            # extension
            highfreq, hf_info = self.high_freq_layer1(highfreq).chunk(2, dim=-1)
            lowfreq, lf_info = self.low_freq_layer1(lowfreq).chunk(2, dim=-1)
            # fusion
            highfreq = self.high_freq_layer2(torch.cat((highfreq, lf_info), dim=-1))
            lowfreq = self.low_freq_layer2(torch.cat((lowfreq, hf_info), dim=-1))
            out = torch.cat((highfreq, lowfreq), dim=0)
        return out


class TemporalFrequencyBlock(nn.Module):
    def __init__(self, dim, num_heads, qkv_bias, attn_drop=0.0, proj_drop=0.0):
        super(TemporalFrequencyBlock, self).__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim * 2, dim * 3, bias=qkv_bias)
        # self.qkv2 = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # self.reduction = nn.Linear(dim * 6, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond):
        # query/value: img tokens; key: condition; mask: if padding tokens
        B, N, C = x.shape

        # qkv1 = self.qkv1(x)
        # qkv2 = self.qkv2(cond)
        qkv = torch.cat((x, cond), dim=-1)
        qkv = self.qkv(qkv)

        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv_permute_shape = (2, 0, 3, 1, 4)
        qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
        q, k, v = qkv.unbind(0)
        dtype = q.dtype
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        # translate attn to float32
        attn = attn.to(torch.float32)
        attn = attn.softmax(dim=-1)
        attn = attn.to(dtype)  # cast back attn to original dtype
        attn = self.attn_drop(attn)
        x = attn @ v

        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


class Encoder_3D(nn.Module):
    """
    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        # conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        super().__init__()

        # self.conv_in = nn.Conv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])
        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv3d(channel_in, channel_in, kernel_size=(3, 3, 3), padding=1, stride=1))
            self.blocks.append(nn.Conv3d(channel_in, channel_out, kernel_size=(3, 3, 3), padding=1, stride=(1, 2, 2)))

        self.conv_out = zero_module(
            nn.Conv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=(3, 3, 3), padding=1, stride=1)
        )

    def forward(self, embedding):
        # embedding = self.conv_in(conditioning)
        # embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)
        return embedding
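
# --- illustration (not part of the original file) ---
# A minimal sketch of the Encoder_3D downsampling behaviour: with the default
# block_out_channels (16, 32, 96, 256) there are three stride-(1, 2, 2) convs,
# so height/width shrink by 8x while the temporal length is preserved. The
# (commented-out) conv_in is skipped, so the toy input already has 16 channels;
# conv_out is zero-initialized, hence the output starts as all zeros.
import torch

enc = Encoder_3D(conditioning_embedding_channels=4)
cond = torch.randn(1, 16, 8, 64, 64)   # (B, C, T, H, W)
out = enc(cond)
print(out.shape)                       # torch.Size([1, 4, 8, 8, 8])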
# ===============================================
# Embedding Layers for Timesteps and Class Labels
# ===============================================
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
        freqs = freqs.to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        if t_freq.dtype != dtype:
            t_freq = t_freq.to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
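
# --- illustration (not part of the original file) ---
# A minimal sketch of the sinusoidal embedding above: each scalar timestep t is
# mapped to [cos(t * f_0), ..., cos(t * f_{D/2-1}), sin(t * f_0), ...] with
# geometrically spaced frequencies f_i = max_period^(-i / (D/2)). Toy values only.
import torch

t = torch.tensor([0.0, 10.0, 500.0])
emb = TimestepEmbedder.timestep_embedding(t, dim=8)   # (3, 8): cos half, then sin half
print(emb.shape)
print(emb[0])                                         # t=0 -> cos terms are 1, sin terms are 0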
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)


class SizeEmbedder(TimestepEmbedder):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.outdim = hidden_size

    def forward(self, s, bs):
        if s.ndim == 1:
            s = s[:, None]
        assert s.ndim == 2
        if s.shape[0] != bs:
            s = s.repeat(bs // s.shape[0], 1)
            assert s.shape[0] == bs
        b, dims = s.shape[0], s.shape[1]
        s = rearrange(s, "b d -> (b d)")
        s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
        s_emb = self.mlp(s_freq)
        s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
        return s_emb

    @property
    def dtype(self):
        return next(self.parameters()).dtype


class CaptionEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """

    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120):
        super().__init__()
        self.y_proj = Mlp(
            in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
        )
        self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return caption

    def forward(self, caption, train, force_drop_ids=None):
        if train:
            assert caption.shape[2:] == self.y_embedding.shape
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids)
        caption = self.y_proj(caption)
        return caption
# ===============================================
# Sine/Cosine Positional Embedding Functions
# ===============================================
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
    if base_size is not None:
        grid_h *= base_size / grid_size[0]
        grid_w *= base_size / grid_size[1]
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
    pos = np.arange(0, length)[..., None] / scale
    return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
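
# --- illustration (not part of the original file) ---
# A minimal sketch of the positional-embedding helpers above: a 2-D sin/cos
# table for an 8x8 token grid with a 16-dim embedding (toy sizes), built by
# concatenating 1-D embeddings of the row and column coordinates.
import numpy as np

pos_embed = get_2d_sincos_pos_embed(embed_dim=16, grid_size=8)
print(pos_embed.shape)      # (64, 16): one row per grid position

pos_embed_1d = get_1d_sincos_pos_embed(embed_dim=16, length=20)
print(pos_embed_1d.shape)   # (20, 16)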