ModelZoo / STAR · Commits

Commit 1f5da520, authored Dec 05, 2025 by yangzhong

    git init

Pipeline #3144 failed with stages in 0 seconds.

Showing 20 changed files with 1493 additions and 0 deletions (+1493 −0).
utils_data/opensora/models/stdit/stdit_qknorm_rope.py                          +423 −0
utils_data/opensora/models/text_encoder/__init__.py                            +3 −0
utils_data/opensora/models/text_encoder/__pycache__/__init__.cpython-39.pyc    +0 −0 (binary)
utils_data/opensora/models/text_encoder/__pycache__/classes.cpython-39.pyc     +0 −0 (binary)
utils_data/opensora/models/text_encoder/__pycache__/clip.cpython-39.pyc        +0 −0 (binary)
utils_data/opensora/models/text_encoder/__pycache__/t5.cpython-39.pyc          +0 −0 (binary)
utils_data/opensora/models/text_encoder/classes.py                             +20 −0
utils_data/opensora/models/text_encoder/clip.py                                +114 −0
utils_data/opensora/models/text_encoder/t5.py                                  +335 −0
utils_data/opensora/models/vae/__init__.py                                     +1 −0
utils_data/opensora/models/vae/__pycache__/__init__.cpython-39.pyc             +0 −0 (binary)
utils_data/opensora/models/vae/__pycache__/vae.cpython-39.pyc                  +0 −0 (binary)
utils_data/opensora/models/vae/vae.py                                          +82 −0
utils_data/opensora/models/vsr/__pycache__/fdie_arch.cpython-39.pyc            +0 −0 (binary)
utils_data/opensora/models/vsr/__pycache__/safmn_arch.cpython-39.pyc           +0 −0 (binary)
utils_data/opensora/models/vsr/__pycache__/sfr_lftg.cpython-39.pyc             +0 −0 (binary)
utils_data/opensora/models/vsr/fdie_arch.py                                    +205 −0
utils_data/opensora/models/vsr/safmn_arch.py                                   +193 −0
utils_data/opensora/models/vsr/sfr_lftg.py                                     +73 −0
utils_data/opensora/registry.py                                                +44 −0
utils_data/opensora/models/stdit/stdit_qknorm_rope.py (new file, mode 100644)

```python
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp

from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
from opensora.models.layers.blocks import (
    Attention,
    Attention_QKNorm_RoPE,
    CaptionEmbedder,
    MultiHeadCrossAttention,
    PatchEmbed3D,
    SeqParallelAttention,
    SeqParallelMultiHeadCrossAttention,
    T2IFinalLayer,
    TimestepEmbedder,
    approx_gelu,
    get_1d_sincos_pos_embed,
    get_2d_sincos_pos_embed,
    get_layernorm,
    t2i_modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
from rotary_embedding_torch import RotaryEmbedding


class STDiTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        d_s=None,
        d_t=None,
        mlp_ratio=4.0,
        drop_path=0.0,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
        rope=None,
        qk_norm=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.enable_flashattn = enable_flashattn
        self._enable_sequence_parallelism = enable_sequence_parallelism

        if enable_sequence_parallelism:
            self.attn_cls = SeqParallelAttention
            self.mha_cls = SeqParallelMultiHeadCrossAttention
        else:
            self.attn_cls = Attention_QKNorm_RoPE
            self.mha_cls = MultiHeadCrossAttention

        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=enable_flashattn,
            qk_norm=qk_norm,
        )
        self.cross_attn = self.mha_cls(hidden_size, num_heads)
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
        self.scale_shift_table_temporal = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5)

        # temporal attention
        self.d_s = d_s
        self.d_t = d_t

        if self._enable_sequence_parallelism:
            sp_size = dist.get_world_size(get_sequence_parallel_group())
            # make sure d_t is divisible by sp_size
            assert d_t % sp_size == 0
            self.d_t = d_t // sp_size

        self.norm_temp = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn_temp = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=self.enable_flashattn,
            rope=rope,
            qk_norm=qk_norm,
        )

    def forward(self, x, y, t, t_temp, mask=None, tpe=None):
        B, N, C = x.shape
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + t.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        shift_tmp, scale_tmp, gate_tmp = (
            self.scale_shift_table_temporal[None] + t_temp.reshape(B, 3, -1)
        ).chunk(3, dim=1)
        x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)

        # spatial branch
        x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
        x_s = self.attn(x_s)
        x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
        x = x + self.drop_path(gate_msa * x_s)

        # modulate
        x_m = t2i_modulate(self.norm_temp(x), shift_tmp, scale_tmp)

        # temporal branch
        x_t = rearrange(x_m, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
        if tpe is not None:
            x_t = x_t + tpe
        x_t = self.attn_temp(x_t)
        x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s)
        x = x + self.drop_path(gate_tmp * x_t)

        # cross attn
        x = x + self.cross_attn(x, y, mask)

        # mlp
        x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

        return x


@MODELS.register_module()
class STDiT_QKNorm_RoPE(nn.Module):
    def __init__(
        self,
        input_size=(1, 32, 32),
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path=0.0,
        no_temporal_pos_emb=False,
        caption_channels=4096,
        model_max_length=120,
        dtype=torch.float32,
        space_scale=1.0,
        time_scale=1.0,
        freeze=None,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
        qk_norm=False,
        rope=False,
    ):
        super().__init__()
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.input_size = input_size
        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
        self.num_patches = num_patches
        self.num_temporal = input_size[0] // patch_size[0]
        self.num_spatial = num_patches // self.num_temporal
        self.num_heads = num_heads
        self.dtype = dtype
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.enable_flashattn = enable_flashattn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.space_scale = space_scale
        self.time_scale = time_scale

        self.register_buffer("pos_embed", self.get_spatial_pos_embed())
        self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())

        self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        self.t_block_temp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True))
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels,
            hidden_size=hidden_size,
            uncond_prob=class_dropout_prob,
            act_layer=approx_gelu,
            token_num=model_max_length,
        )

        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
        if rope:
            RoPE = RotaryEmbedding(dim=self.hidden_size // self.num_heads)
            self.rope = RoPE.rotate_queries_or_keys
        else:
            self.rope = None
        self.blocks = nn.ModuleList(
            [
                STDiTBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    drop_path=drop_path[i],
                    enable_flashattn=self.enable_flashattn,
                    enable_layernorm_kernel=self.enable_layernorm_kernel,
                    enable_sequence_parallelism=enable_sequence_parallelism,
                    d_t=self.num_temporal,
                    d_s=self.num_spatial,
                    rope=self.rope,
                    qk_norm=qk_norm,
                )
                for i in range(self.depth)
            ]
        )
        self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)

        # init model
        self.initialize_weights()
        self.initialize_temporal()
        if freeze is not None:
            assert freeze in ["not_temporal", "text"]
            if freeze == "not_temporal":
                self.freeze_not_temporal()
            elif freeze == "text":
                self.freeze_text()

        # sequence parallel related configs
        self.enable_sequence_parallelism = enable_sequence_parallelism
        if enable_sequence_parallelism:
            self.sp_rank = dist.get_rank(get_sequence_parallel_group())
        else:
            self.sp_rank = None

    def forward(self, x, timestep, y, mask=None):
        """
        Forward pass of STDiT.

        Args:
            x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
            timestep (torch.Tensor): diffusion time steps; of shape [B]
            y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
            mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]

        Returns:
            x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # [B, N, C]
        x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
        x = x + self.pos_embed
        x = rearrange(x, "B T S C -> B (T S) C")

        # shard over the sequence dim if sp is enabled
        if self.enable_sequence_parallelism:
            x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")

        t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
        t0 = self.t_block(t)  # [B, C]
        t0_temp = self.t_block_temp(t)  # [B, C]
        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # blocks
        for i, block in enumerate(self.blocks):
            if i == 0:
                if self.enable_sequence_parallelism:
                    tpe = torch.chunk(
                        self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
                    )[self.sp_rank].contiguous()
                else:
                    tpe = self.pos_embed_temporal
            else:
                tpe = None
            x = auto_grad_checkpoint(block, x, y, t0, t0_temp, y_lens, tpe)

        if self.enable_sequence_parallelism:
            x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
            # x.shape: [B, N, C]

        # final process
        x = self.final_layer(x, t)  # [B, N, C=T_p * H_p * W_p * C_out]
        x = self.unpatchify(x)  # [B, C_out, T, H, W]

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def unpatchify(self, x):
        """
        Args:
            x (torch.Tensor): of shape [B, N, C]

        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """
        N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x,
            "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t,
            N_h=N_h,
            N_w=N_w,
            T_p=T_p,
            H_p=H_p,
            W_p=W_p,
            C_out=self.out_channels,
        )
        return x

    def unpatchify_old(self, x):
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size

        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def get_spatial_pos_embed(self, grid_size=None):
        if grid_size is None:
            grid_size = self.input_size[1:]
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size,
            (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
            scale=self.space_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[0] // self.patch_size[0],
            scale=self.time_scale,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def freeze_not_temporal(self):
        for n, p in self.named_parameters():
            if "attn_temp" not in n:
                p.requires_grad = False

    def freeze_text(self):
        for n, p in self.named_parameters():
            if "cross_attn" in n:
                p.requires_grad = False

    def initialize_temporal(self):
        for block in self.blocks:
            nn.init.constant_(block.attn_temp.proj.weight, 0)
            nn.init.constant_(block.attn_temp.proj.bias, 0)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)
        nn.init.normal_(self.t_block_temp[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

        # Zero-out adaLN modulation layers in PixArt blocks:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)


@MODELS.register_module("STDiT_QKNorm_RoPE_XL/2")
def STDiT_QKNorm_RoPE_XL_2(from_pretrained=None, **kwargs):
    model = STDiT_QKNorm_RoPE(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model
```
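For orientation, a minimal smoke-test sketch of the model above, assuming the repo's `opensora` package (including `opensora.models.layers.blocks`) is importable. The tiny depth/hidden-size configuration and the tensor sizes are illustrative, not from this commit; the sketch only checks that a forward pass matches the shapes documented in the docstring.

```python
import torch

from opensora.models.stdit.stdit_qknorm_rope import STDiT_QKNorm_RoPE

# Latent video [B, C, T, H, W] matching the default input_size=(1, 32, 32).
x = torch.randn(2, 4, 1, 32, 32)
timestep = torch.randint(0, 1000, (2,))      # [B]
y = torch.randn(2, 1, 120, 4096)             # [B, 1, N_token, caption_channels]
mask = torch.ones(2, 120, dtype=torch.long)  # [B, N_token]

# Illustrative small config; the registered XL/2 builder uses depth=28, hidden_size=1152.
model = STDiT_QKNorm_RoPE(depth=2, hidden_size=64, num_heads=4, qk_norm=True, rope=True)
out = model(x, timestep, y, mask=mask)
# pred_sigma=True doubles the channel dim: [B, 2*C, T, H, W]
print(out.shape)  # torch.Size([2, 8, 1, 32, 32])
```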
utils_data/opensora/models/text_encoder/__init__.py (new file, mode 100644)

```python
from .classes import ClassEncoder
from .clip import ClipEncoder
from .t5 import T5Encoder
```
utils_data/opensora/models/text_encoder/__pycache__/__init__.cpython-39.pyc (new binary file, mode 100644)
utils_data/opensora/models/text_encoder/__pycache__/classes.cpython-39.pyc (new binary file, mode 100644)
utils_data/opensora/models/text_encoder/__pycache__/clip.cpython-39.pyc (new binary file, mode 100644)
utils_data/opensora/models/text_encoder/__pycache__/t5.cpython-39.pyc (new binary file, mode 100644)
utils_data/opensora/models/text_encoder/classes.py (new file, mode 100644)

```python
import torch

from opensora.registry import MODELS


@MODELS.register_module("classes")
class ClassEncoder:
    def __init__(self, num_classes, model_max_length=None, device="cuda", dtype=torch.float):
        self.num_classes = num_classes
        self.y_embedder = None

        self.model_max_length = model_max_length
        self.output_dim = None
        self.device = device

    def encode(self, text):
        return dict(y=torch.tensor([int(t) for t in text]).to(self.device))

    def null(self, n):
        return torch.tensor([self.num_classes] * n).to(self.device)
```
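A quick illustrative check of this encoder's contract (the labels are hypothetical): class indices arrive as strings and come back as an integer tensor, with index `num_classes` reserved as the null class for classifier-free guidance.

```python
import torch

from opensora.models.text_encoder.classes import ClassEncoder

enc = ClassEncoder(num_classes=10, device="cpu")
print(enc.encode(["3", "7"])["y"])  # tensor([3, 7])
print(enc.null(2))                  # tensor([10, 10]) -- num_classes acts as the null label
```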
utils_data/opensora/models/text_encoder/clip.py (new file, mode 100644)

```python
# Copyright 2024 Vchitect/Latte
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from Latte
#
# This file is adapted from the Latte project.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Latte: https://github.com/Vchitect/Latte
# DiT: https://github.com/facebookresearch/DiT/tree/main
# --------------------------------------------------------

import torch
import torch.nn as nn
import transformers
from transformers import CLIPTextModel, CLIPTokenizer

from opensora.registry import MODELS

transformers.logging.set_verbosity_error()


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)."""

    def __init__(self, path="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(path)
        self.transformer = CLIPTextModel.from_pretrained(path)
        self.device = device
        self.max_length = max_length
        self._freeze()

    def _freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        pooled_z = outputs.pooler_output
        return z, pooled_z

    def encode(self, text):
        return self(text)


@MODELS.register_module("clip")
class ClipEncoder:
    """
    Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
    """

    def __init__(
        self,
        from_pretrained,
        model_max_length=77,
        device="cuda",
        dtype=torch.float,
    ):
        super().__init__()
        assert from_pretrained is not None, "Please specify the path to the CLIP model"

        self.text_encoder = FrozenCLIPEmbedder(path=from_pretrained, max_length=model_max_length).to(device, dtype)
        self.y_embedder = None

        self.model_max_length = model_max_length
        self.output_dim = self.text_encoder.transformer.config.hidden_size

    def encode(self, text):
        _, pooled_embeddings = self.text_encoder.encode(text)
        y = pooled_embeddings.unsqueeze(1).unsqueeze(1)
        return dict(y=y)

    def null(self, n):
        null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
        return null_y

    def to(self, dtype):
        self.text_encoder = self.text_encoder.to(dtype)
        return self
```
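A hedged usage sketch, assuming a CLIP text checkpoint is available; `openai/clip-vit-base-patch32` is a stand-in, and any `CLIPTextModel`-compatible checkpoint works. `encode` returns the pooled embedding with two singleton dims inserted so it matches the `[B, 1, N_token, C]` layout the diffusion model expects:

```python
from opensora.models.text_encoder.clip import ClipEncoder

# Stand-in checkpoint; downloads from the HF hub on first use.
encoder = ClipEncoder(from_pretrained="openai/clip-vit-base-patch32", device="cpu")
out = encoder.encode(["a cat playing piano", "a red car"])
print(out["y"].shape)  # [B, 1, 1, hidden_size], e.g. torch.Size([2, 1, 1, 512])
```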
utils_data/opensora/models/text_encoder/t5.py (new file, mode 100644)

```python
# Adapted from PixArt
#
# Copyright (C) 2023 PixArt-alpha/PixArt-alpha
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
# T5: https://github.com/google-research/text-to-text-transfer-transformer
# --------------------------------------------------------

import html
import re

import ftfy
import torch
from transformers import AutoTokenizer, T5EncoderModel

from opensora.registry import MODELS


class T5Embedder:
    def __init__(
        self,
        device,
        from_pretrained=None,
        *,
        cache_dir=None,
        hf_token=None,
        use_text_preprocessing=True,
        t5_model_kwargs=None,
        torch_dtype=None,
        use_offload_folder=None,
        model_max_length=120,
        local_files_only=False,
    ):
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        self.cache_dir = cache_dir

        if t5_model_kwargs is None:
            t5_model_kwargs = {
                "low_cpu_mem_usage": True,
                "torch_dtype": self.torch_dtype,
            }
            if use_offload_folder is not None:
                t5_model_kwargs["offload_folder"] = use_offload_folder
                t5_model_kwargs["device_map"] = {
                    "shared": self.device,
                    "encoder.embed_tokens": self.device,
                    "encoder.block.0": self.device,
                    "encoder.block.1": self.device,
                    "encoder.block.2": self.device,
                    "encoder.block.3": self.device,
                    "encoder.block.4": self.device,
                    "encoder.block.5": self.device,
                    "encoder.block.6": self.device,
                    "encoder.block.7": self.device,
                    "encoder.block.8": self.device,
                    "encoder.block.9": self.device,
                    "encoder.block.10": self.device,
                    "encoder.block.11": self.device,
                    "encoder.block.12": "disk",
                    "encoder.block.13": "disk",
                    "encoder.block.14": "disk",
                    "encoder.block.15": "disk",
                    "encoder.block.16": "disk",
                    "encoder.block.17": "disk",
                    "encoder.block.18": "disk",
                    "encoder.block.19": "disk",
                    "encoder.block.20": "disk",
                    "encoder.block.21": "disk",
                    "encoder.block.22": "disk",
                    "encoder.block.23": "disk",
                    "encoder.final_layer_norm": "disk",
                    "encoder.dropout": "disk",
                }
            else:
                t5_model_kwargs["device_map"] = {
                    "shared": self.device,
                    "encoder": self.device,
                }

        self.use_text_preprocessing = use_text_preprocessing
        self.hf_token = hf_token

        self.tokenizer = AutoTokenizer.from_pretrained(
            from_pretrained,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self.model = T5EncoderModel.from_pretrained(
            from_pretrained,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            **t5_model_kwargs,
        ).eval()
        self.model_max_length = model_max_length

    def get_text_embeddings(self, texts):
        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=self.model_max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        input_ids = text_tokens_and_mask["input_ids"].to(self.device)
        attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
        with torch.no_grad():
            text_encoder_embs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )["last_hidden_state"].detach()
        return text_encoder_embs, attention_mask


@MODELS.register_module("t5")
class T5Encoder:
    def __init__(
        self,
        from_pretrained=None,
        model_max_length=120,
        device="cuda",
        dtype=torch.float,
        cache_dir=None,
        shardformer=False,
        local_files_only=False,
    ):
        assert from_pretrained is not None, "Please specify the path to the T5 model"

        self.t5 = T5Embedder(
            device=device,
            torch_dtype=dtype,
            from_pretrained=from_pretrained,
            cache_dir=cache_dir,
            model_max_length=model_max_length,
            local_files_only=local_files_only,
        )
        self.t5.model.to(dtype=dtype)
        self.y_embedder = None

        self.model_max_length = model_max_length
        self.output_dim = self.t5.model.config.d_model
        self.dtype = dtype

        if shardformer:
            self.shardformer_t5()

    def shardformer_t5(self):
        from colossalai.shardformer import ShardConfig, ShardFormer

        from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy
        from opensora.utils.misc import requires_grad

        shard_config = ShardConfig(
            tensor_parallel_process_group=None,
            pipeline_stage_manager=None,
            enable_tensor_parallelism=False,
            enable_fused_normalization=False,
            enable_flash_attention=False,
            enable_jit_fused=True,
            enable_sequence_parallelism=False,
            enable_sequence_overlap=False,
        )
        shard_former = ShardFormer(shard_config=shard_config)
        optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
        self.t5.model = optim_model.to(self.dtype)

        # ensure the weights are frozen
        requires_grad(self.t5.model, False)

    def encode(self, text):
        caption_embs, emb_masks = self.t5.get_text_embeddings(text)
        caption_embs = caption_embs[:, None]
        return dict(y=caption_embs, mask=emb_masks)

    def null(self, n):
        null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
        return null_y


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


BAD_PUNCT_REGEX = re.compile(
    r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
)  # noqa


def clean_caption(caption):
    import urllib.parse as ul

    from bs4 import BeautifulSoup

    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # все виды тире / all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # кавычки к одному стандарту / normalize quotation marks
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # &quot;
    caption = re.sub(r"&quot;?", "", caption)
    # &amp
    caption = re.sub(r"&amp", "", caption)

    # ip adresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    #
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(BAD_PUNCT_REGEX, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = basic_clean(caption)

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    caption = caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()


def text_preprocessing(text, use_text_preprocessing: bool = True):
    if use_text_preprocessing:
        # The exact text cleaning as was in the training stage:
        text = clean_caption(text)
        text = clean_caption(text)
        return text
    else:
        return text.lower().strip()
```
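The caption-cleaning helpers can be exercised on their own, assuming `ftfy` and `beautifulsoup4` (plus the repo's other imports at module load) are installed; the sample caption below is made up:

```python
from opensora.models.text_encoder.t5 import text_preprocessing

raw = "Check https://example.com for #12345 FREE SHIPPING!! <b>A cute cat</b> IMG_0042.jpg"
print(text_preprocessing(raw))
# URLs, long hashtags, promo phrases and filenames are stripped; the text is lower-cased
# and whitespace-normalized, e.g. something like "check for a cute cat".
```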
utils_data/opensora/models/vae/__init__.py (new file, mode 100644)

```python
from .vae import VideoAutoencoderKL, VideoAutoencoderKLTemporalDecoder
```
utils_data/opensora/models/vae/__pycache__/__init__.cpython-39.pyc (new binary file, mode 100644)
utils_data/opensora/models/vae/__pycache__/vae.cpython-39.pyc (new binary file, mode 100644)
utils_data/opensora/models/vae/vae.py (new file, mode 100644)

```python
import torch
import torch.nn as nn
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from einops import rearrange

from opensora.registry import MODELS


@MODELS.register_module()
class VideoAutoencoderKL(nn.Module):
    def __init__(self, from_pretrained=None, micro_batch_size=None):
        super().__init__()
        self.module = AutoencoderKL.from_pretrained(from_pretrained)
        self.out_channels = self.module.config.latent_channels
        self.patch_size = (1, 8, 8)
        self.micro_batch_size = micro_batch_size

    def encode(self, x):
        # x: (B, C, T, H, W)
        B = x.shape[0]
        x = rearrange(x, "B C T H W -> (B T) C H W")

        if self.micro_batch_size is None:
            x = self.module.encode(x).latent_dist.sample().mul_(0.18215)
        else:
            bs = self.micro_batch_size
            x_out = []
            for i in range(0, x.shape[0], bs):
                x_bs = x[i : i + bs]
                x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215)
                x_out.append(x_bs)
            x = torch.cat(x_out, dim=0)

        x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
        return x

    def decode(self, x):
        # x: (B, C, T, H, W)
        B = x.shape[0]
        x = rearrange(x, "B C T H W -> (B T) C H W")
        if self.micro_batch_size is None:
            x = self.module.decode(x / 0.18215).sample
        else:
            bs = self.micro_batch_size
            x_out = []
            for i in range(0, x.shape[0], bs):
                x_bs = x[i : i + bs]
                x_bs = self.module.decode(x_bs / 0.18215).sample
                x_out.append(x_bs)
            x = torch.cat(x_out, dim=0)
        x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
        return x

    def get_latent_size(self, input_size):
        for i in range(3):
            assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
        input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
        return input_size


@MODELS.register_module()
class VideoAutoencoderKLTemporalDecoder(nn.Module):
    def __init__(self, from_pretrained=None):
        super().__init__()
        self.module = AutoencoderKLTemporalDecoder.from_pretrained(from_pretrained)
        self.out_channels = self.module.config.latent_channels
        self.patch_size = (1, 8, 8)

    def encode(self, x):
        raise NotImplementedError

    def decode(self, x):
        B, _, T = x.shape[:3]
        x = rearrange(x, "B C T H W -> (B T) C H W")
        x = self.module.decode(x / 0.18215, num_frames=T).sample
        x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
        return x

    def get_latent_size(self, input_size):
        for i in range(3):
            assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
        input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
        return input_size
```
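A minimal roundtrip sketch, assuming a diffusers `AutoencoderKL` checkpoint is available (`stabilityai/sd-vae-ft-ema`, with 4 latent channels, is a stand-in). It illustrates the `(1, 8, 8)` spatial patching and the micro-batched frame-by-frame encode:

```python
import torch

from opensora.models.vae.vae import VideoAutoencoderKL

vae = VideoAutoencoderKL(from_pretrained="stabilityai/sd-vae-ft-ema", micro_batch_size=4)
video = torch.randn(1, 3, 8, 64, 64)  # [B, C, T, H, W]

with torch.no_grad():
    z = vae.encode(video)  # frames are encoded independently, 8x spatial downsampling
    rec = vae.decode(z)

print(z.shape)    # torch.Size([1, 4, 8, 8, 8])
print(rec.shape)  # torch.Size([1, 3, 8, 64, 64])
print(vae.get_latent_size([8, 64, 64]))  # [8, 8, 8]
```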
utils_data/opensora/models/vsr/__pycache__/fdie_arch.cpython-39.pyc (new binary file, mode 100644)
utils_data/opensora/models/vsr/__pycache__/safmn_arch.cpython-39.pyc (new binary file, mode 100644)
utils_data/opensora/models/vsr/__pycache__/sfr_lftg.cpython-39.pyc (new binary file, mode 100644)
utils_data/opensora/models/vsr/fdie_arch.py (new file, mode 100644)

```python
import torch
import torch.nn as nn
from opensora.models.vsr.safmn_arch import SAFMN
import torch.nn.functional as F
from einops import rearrange
from timm.models.vision_transformer import Mlp

from opensora.models.layers.blocks import (
    Attention,
    MultiHeadCrossAttention,
    PatchEmbed3D,
    get_1d_sincos_pos_embed,
    get_2d_sincos_pos_embed,
    get_layernorm,
)


# high pass filter
def high_pass_filter(x, kernel_size=21):
    """
    Apply a Gaussian high-pass filter to the input tensor, separating it into
    high- and low-frequency parts.

    Args:
        x (torch.Tensor): input tensor of shape [B, C, T, H, W], values in [-1, 1].
        kernel_size (int): size of the Gaussian kernel.

    Returns:
        high_freq (torch.Tensor): high-frequency part, same shape as x.
        low_freq (torch.Tensor): low-frequency part, same shape as x.
    """
    # compute the sigma value
    sigma = kernel_size / 6
    # device and dtype of the input tensor
    device, dtype = x.device, x.dtype

    # reshape [B, C, T, H, W] -> [B*T, C, H, W]
    B, C, T, H, W = x.shape
    x_reshaped = x.contiguous().view(B * T, C, H, W)

    # build the Gaussian kernel
    def get_gaussian_kernel(kernel_size, sigma):
        axis = torch.arange(kernel_size, dtype=dtype, device=device) - kernel_size // 2
        gaussian = torch.exp(-0.5 * (axis / sigma) ** 2)
        gaussian /= gaussian.sum()
        return gaussian

    gaussian_1d = get_gaussian_kernel(kernel_size, sigma)
    gaussian_2d = torch.outer(gaussian_1d, gaussian_1d)
    gaussian_3d = gaussian_2d.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]

    # expand the Gaussian kernel to 4D, one kernel per channel
    gaussian_kernel = gaussian_3d.expand(C, 1, kernel_size, kernel_size)

    # convolve with F.conv2d
    padding = kernel_size // 2

    # low-frequency part
    low_freq_reshaped = F.conv2d(x_reshaped, gaussian_kernel, padding=padding, groups=C)
    # high-frequency part
    high_freq_reshaped = x_reshaped - low_freq_reshaped

    # reshape back [B*T, C, H, W] -> [B, C, T, H, W]
    low_freq = low_freq_reshaped.view(B, C, T, H, W)
    high_freq = high_freq_reshaped.view(B, C, T, H, W)

    return high_freq, low_freq


# depth-wise separable convolution
class DepthWiseSeparableResBlock(nn.Module):
    def __init__(self, in_channels, kernel_size=3, stride=1, padding=1, bias=False):
        super(DepthWiseSeparableResBlock, self).__init__()
        self.dwconv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias)  # groups=in_channels,
        # self.conv1 = nn.Conv2d(in_channels, in_channels, 1, bias=bias)
        self.gelu = nn.GELU()
        self.dwconv2 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias)  # groups=in_channels,
        # self.conv2 = nn.Conv2d(in_channels, in_channels, 1, bias=bias)

    def forward(self, x):
        residual = x
        out = self.dwconv1(x)
        # out = self.conv1(out)
        out = self.gelu(out)
        out = self.dwconv2(out)
        # out = self.conv2(out)
        out += residual
        return out


# temporal transformer block
class TemporalTransformerBlock(nn.Module):
    def __init__(self):
        super(TemporalTransformerBlock, self).__init__()
        # temporal norm
        self.temporal_norm = get_layernorm(1152, eps=1e-6, affine=False, use_kernel=True)
        # temporal self-attention
        self.temporal_attn = Attention(dim=1152, num_heads=16, qkv_bias=True, enable_flashattn=True)
        # ffn
        self.temporal_ffn = Mlp(in_features=1152, hidden_features=4608, out_features=1152, act_layer=nn.GELU)

    def forward(self, x):
        residual = x
        out = self.temporal_norm(x)
        out = self.temporal_attn(out)
        out = self.temporal_ffn(out)
        out += residual
        return out


# frequency-decoupled information extractor
class FrequencyDecoupledInfoExtractor(nn.Module):
    def __init__(self, in_channels, hidden_channels, kernel_size=3, stride=1, padding=1, bias=True):
        super(FrequencyDecoupledInfoExtractor, self).__init__()

        ### spatial branch ###
        self.safmn = SAFMN(dim=128, n_blocks=16, ffn_scale=2.0, upscaling_factor=4, use_res=True)
        state_dict = torch.load('/mnt/bn/videodataset/VSR/pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth')
        self.safmn.load_state_dict(state_dict['params_ema'], strict=True)

        # high-frequency branch
        # self.hf_convin = nn.Conv2d(in_channels, hidden_channels, kernel_size, stride, padding, bias=bias)
        # self.hf_convout = nn.Conv2d(hidden_channels, in_channels, kernel_size, stride, padding, bias=bias)
        # hf_layer = []
        # for i in range(8):
        #     hf_layer.append(DepthWiseSeparableResBlock(hidden_channels, kernel_size, stride=1, padding=1, bias=bias))
        # self.hf_body = nn.Sequential(*hf_layer)
        self.safmn1 = SAFMN(dim=72, n_blocks=8, ffn_scale=2.0, upscaling_factor=1, in_dim=6, use_res=True)

        # low-frequency branch
        # self.lf_convin = nn.Conv2d(in_channels, hidden_channels, kernel_size, stride, padding, bias=bias)
        # self.lf_convout = nn.Conv2d(hidden_channels, in_channels, kernel_size, stride, padding, bias=bias)
        # lf_layer = []
        # for i in range(8):
        #     lf_layer.append(DepthWiseSeparableResBlock(hidden_channels, kernel_size, stride=1, padding=1, bias=bias))
        # self.lf_body = nn.Sequential(*lf_layer)
        self.safmn2 = SAFMN(dim=72, n_blocks=8, ffn_scale=2.0, upscaling_factor=1, in_dim=6, use_res=True)

        ### temporal branch ###
        layer = []
        for i in range(3):
            layer.append(TemporalTransformerBlock())
        self.temporal_body = nn.Sequential(*layer)

    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            embed_dim=1152,
            length=16,
            scale=1.0,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def spatial_forward(self, x):
        # the pretrained SAFMN restorer is frozen, so run it without gradients
        with torch.no_grad():
            x = rearrange(x, 'B C T H W -> (B T) C H W')
            x = F.interpolate(x, scale_factor=1 / 4, mode='bilinear')
            clean_image = self.safmn(x)
            clean_image = rearrange(clean_image, '(B T) C H W -> B C T H W', T=16)
        high_freq, low_freq = high_pass_filter(clean_image)
        fea_decouple = torch.cat([high_freq, low_freq], dim=1)
        fea_decouple = rearrange(fea_decouple, 'B C T H W -> (B T) C H W')

        # high-frequency branch
        # hf_out = self.hf_convin(high_freq)
        # hf_out = self.hf_body(hf_out)
        # hf_out = self.hf_convout(hf_out) + high_freq
        hf_out = self.safmn1(fea_decouple)
        hf_out = rearrange(hf_out, '(B T) C H W -> B C T H W', T=16)

        # low-frequency branch
        # lf_out = self.lf_convin(low_freq)
        # lf_out = self.lf_body(lf_out)
        # lf_out = self.lf_convout(lf_out) + low_freq
        lf_out = self.safmn2(fea_decouple)
        lf_out = rearrange(lf_out, '(B T) C H W -> B C T H W', T=16)

        return clean_image, hf_out, lf_out

    def temporal_forward(self, x):
        x = rearrange(x, "B (T S) C -> (B S) T C", T=16)
        tpe = self.get_temporal_pos_embed().to(x.device, x.dtype)
        x = x + tpe
        x = self.temporal_body(x)
        x = rearrange(x, "(B S) T C -> B (T S) C", S=256)
        return x
```
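A small sanity check of the filter above, assuming the repo's imports resolve: by construction `high = x - low`, so the two parts sum back to the input up to float rounding:

```python
import torch

from opensora.models.vsr.fdie_arch import high_pass_filter

x = torch.randn(1, 3, 4, 32, 32)  # [B, C, T, H, W]
high, low = high_pass_filter(x, kernel_size=21)

print(high.shape, low.shape)                     # both torch.Size([1, 3, 4, 32, 32])
print(torch.allclose(high + low, x, atol=1e-5))  # True: the frequency split is exact
```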
utils_data/opensora/models/vsr/safmn_arch.py (new file, mode 100644)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import ops

# from basicsr.utils.registry import ARCH_REGISTRY


# Layer Norm
class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


# SE
class SqueezeExcitation(nn.Module):
    def __init__(self, dim, shrinkage_rate=0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, hidden_dim, 1, 1, 0),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1, 1, 0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.gate(x)


# Channel MLP: Conv1*1 -> Conv1*1
class ChannelMLP(nn.Module):
    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)

        self.mlp = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 1, 1, 0),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1, 1, 0),
        )

    def forward(self, x):
        return self.mlp(x)


# MBConv: Conv1*1 -> DW Conv3*3 -> [SE] -> Conv1*1
class MBConv(nn.Module):
    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)

        self.mbconv = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 1, 1, 0),
            nn.GELU(),
            nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim),
            nn.GELU(),
            SqueezeExcitation(hidden_dim),
            nn.Conv2d(hidden_dim, dim, 1, 1, 0),
        )

    def forward(self, x):
        return self.mbconv(x)


# CCM
class CCM(nn.Module):
    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)

        self.ccm = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1, 1, 0),
        )

    def forward(self, x):
        return self.ccm(x)


# SAFM
class SAFM(nn.Module):
    def __init__(self, dim, n_levels=4):
        super().__init__()
        self.n_levels = n_levels
        chunk_dim = dim // n_levels

        # Spatial Weighting
        self.mfr = nn.ModuleList(
            [nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim) for i in range(self.n_levels)]
        )

        # Feature Aggregation
        self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)

        # Activation
        self.act = nn.GELU()

    def forward(self, x):
        h, w = x.size()[-2:]

        xc = x.chunk(self.n_levels, dim=1)
        out = []
        for i in range(self.n_levels):
            if i > 0:
                p_size = (h // 2**i, w // 2**i)
                s = F.adaptive_max_pool2d(xc[i], p_size)
                s = self.mfr[i](s)
                s = F.interpolate(s, size=(h, w), mode='nearest')
            else:
                s = self.mfr[i](xc[i])
            out.append(s)

        out = self.aggr(torch.cat(out, dim=1))
        out = self.act(out) * x
        return out


class AttBlock(nn.Module):
    def __init__(self, dim, ffn_scale=2.0):
        super().__init__()
        self.norm1 = LayerNorm(dim)
        self.norm2 = LayerNorm(dim)

        # Multiscale Block
        self.safm = SAFM(dim)
        # Feedforward layer
        self.ccm = CCM(dim, ffn_scale)

    def forward(self, x):
        x = self.safm(self.norm1(x)) + x
        x = self.ccm(self.norm2(x)) + x
        return x


# @ARCH_REGISTRY.register()
class SAFMN(nn.Module):
    def __init__(self, dim, n_blocks=8, ffn_scale=2.0, upscaling_factor=4, in_dim=3, use_res=True):
        super().__init__()
        self.use_res = use_res
        self.to_feat = nn.Conv2d(in_dim, dim, 3, 1, 1)

        self.feats = nn.Sequential(*[AttBlock(dim, ffn_scale) for _ in range(n_blocks)])

        self.to_img = nn.Sequential(
            nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1),
            nn.PixelShuffle(upscaling_factor),
        )

    def forward(self, x):
        x = self.to_feat(x)
        if self.use_res:
            x = self.feats(x) + x
        else:
            x = self.feats(x)
        x = self.to_img(x)
        return x


# if __name__ == '__main__':
#     ############# Test Model Complexity #############
#     from fvcore.nn import flop_count_table, FlopCountAnalysis, ActivationCountAnalysis
#     # x = torch.randn(1, 3, 640, 360)
#     # x = torch.randn(1, 3, 427, 240)
#     x = torch.randn(1, 3, 320, 180)
#     # x = torch.randn(1, 3, 256, 256)
#
#     model = SAFMN(dim=36, n_blocks=8, ffn_scale=2.0, upscaling_factor=4)
#     # model = SAFMN(dim=36, n_blocks=12, ffn_scale=2.0, upscaling_factor=2)
#     print(model)
#     print(f'params: {sum(map(lambda x: x.numel(), model.parameters()))}')
#     print(flop_count_table(FlopCountAnalysis(model, x), activations=ActivationCountAnalysis(model, x)))
#     output = model(x)
#     print(output.shape)
```
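SAFMN itself is self-contained (the commented-out complexity test above exercises it the same way); a minimal shape check of the x4 upscaler with illustrative sizes:

```python
import torch

from opensora.models.vsr.safmn_arch import SAFMN

model = SAFMN(dim=36, n_blocks=8, ffn_scale=2.0, upscaling_factor=4)
x = torch.randn(1, 3, 64, 64)

with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 256, 256]) -- PixelShuffle gives a 4x larger image
```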
utils_data/opensora/models/vsr/sfr_lftg.py (new file, mode 100644)

```python
import torch
import torch.nn as nn
import xformers.ops


# spatial feature refiner
class SpatialFeatureRefiner(nn.Module):
    def __init__(self, hidden_channels):
        super(SpatialFeatureRefiner, self).__init__()
        # high-frequency branch
        self.hf_linear = nn.Linear(hidden_channels, hidden_channels * 2)
        # low-frequency branch
        self.lf_linear = nn.Linear(hidden_channels, hidden_channels * 2)
        # fusion
        self.gelu = nn.GELU()
        self.fusion_linear = nn.Linear(hidden_channels * 2, hidden_channels)

    def forward(self, hf_feature, lf_feature, x):
        # high-frequency branch
        hf_feature = self.hf_linear(hf_feature)
        scale_hf, shift_hf = hf_feature.chunk(2, dim=-1)
        x_hf = x * scale_hf + shift_hf

        # low-frequency branch
        lf_feature = self.lf_linear(lf_feature)
        scale_lf, shift_lf = lf_feature.chunk(2, dim=-1)
        x_lf = x * scale_lf + shift_lf

        # fusion
        x_fusion = torch.cat([x_hf, x_lf], dim=-1)
        x_fusion = self.gelu(x_fusion)
        x_fusion = self.fusion_linear(x_fusion)
        return x_fusion


# low-frequency temporal guider
class LFTemporalGuider(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(LFTemporalGuider, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        B, N, C = x.shape

        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)

        attn_bias = None
        if mask is not None:
            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)

        x = x.view(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
```
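A hedged shape-check sketch for the refiner, assuming `xformers` is installed (it is imported at module load even though only `LFTemporalGuider` uses it); the token sizes are illustrative. The refiner modulates the input tokens twice, once with high-frequency and once with low-frequency scale/shift features, then fuses the two views back to the original width:

```python
import torch

from opensora.models.vsr.sfr_lftg import SpatialFeatureRefiner

refiner = SpatialFeatureRefiner(hidden_channels=64)
tokens = torch.randn(2, 16, 64)  # [B, N, C] feature tokens
hf = torch.randn(2, 16, 64)      # high-frequency guidance features
lf = torch.randn(2, 16, 64)      # low-frequency guidance features

out = refiner(hf, lf, tokens)
print(out.shape)  # torch.Size([2, 16, 64]) -- same shape as the input tokens
```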
utils_data/opensora/registry.py (new file, mode 100644)

```python
from copy import deepcopy

import torch.nn as nn
from mmengine.registry import Registry


def build_module(module, builder, **kwargs):
    """Build module from config or return the module itself.

    Args:
        module (Union[dict, nn.Module]): The module to build.
        builder (Registry): The registry to build module.
        *args, **kwargs: Arguments passed to build function.

    Returns:
        Any: The built module.
    """
    if isinstance(module, dict):
        cfg = deepcopy(module)
        for k, v in kwargs.items():
            cfg[k] = v
        return builder.build(cfg)
    elif isinstance(module, nn.Module):
        return module
    elif module is None:
        return None
    else:
        raise TypeError(f"Only support dict and nn.Module, but got {type(module)}.")


MODELS = Registry(
    "model",
    locations=["opensora.models"],
)

SCHEDULERS = Registry(
    "scheduler",
    locations=["opensora.schedulers"],
)

DATASETS = Registry(
    "dataset",
    locations=["opensora.datasets"],
)
```
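A sketch of how `build_module` is meant to be used, with a toy class registered purely for the demo (the `toy` name and `ToyModel` are hypothetical): config dicts go through the registry, extra kwargs override config keys, and modules or `None` pass straight through.

```python
import torch.nn as nn

from opensora.registry import MODELS, build_module


@MODELS.register_module("toy")  # hypothetical class, registered only for this demo
class ToyModel(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)


# Config dicts are built through the registry; extra kwargs override config keys.
model = build_module(dict(type="toy", dim=4), MODELS, dim=16)
print(model.proj)  # Linear(in_features=16, out_features=16, bias=True)

# An nn.Module passes straight through, and None returns None.
assert build_module(model, MODELS) is model
assert build_module(None, MODELS) is None
```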