ModelZoo / STAR · Commits

Commit 1f5da520, authored Dec 05, 2025 by yangzhong

git init

Pipeline #3144 failed in 0 seconds · Changes: 326 · Pipelines: 1

Showing 20 changed files with 2247 additions and 0 deletions (+2247 −0)
cogvideox-based/sat/diffusion_video.py  +380 −0
cogvideox-based/sat/dit_video_concat.py  +818 −0
cogvideox-based/sat/inference_sr.sh  +13 −0
cogvideox-based/sat/requirements.txt  +17 −0
cogvideox-based/sat/sample_sr.py  +268 −0
cogvideox-based/sat/sgm/__init__.py  +4 −0
cogvideox-based/sat/sgm/__pycache__/__init__.cpython-39.pyc  +0 −0
cogvideox-based/sat/sgm/__pycache__/util.cpython-39.pyc  +0 −0
cogvideox-based/sat/sgm/__pycache__/webds.cpython-39.pyc  +0 −0
cogvideox-based/sat/sgm/lr_scheduler.py  +110 −0
cogvideox-based/sat/sgm/models/__init__.py  +1 −0
cogvideox-based/sat/sgm/models/__pycache__/__init__.cpython-39.pyc  +0 −0
cogvideox-based/sat/sgm/models/__pycache__/autoencoder.cpython-39.pyc  +0 −0
cogvideox-based/sat/sgm/models/autoencoder.py  +630 −0
cogvideox-based/sat/sgm/modules/__init__.py  +6 −0
cogvideox-based/sat/sgm/modules/__pycache__/__init__.cpython-39.pyc  +0 −0
cogvideox-based/sat/sgm/modules/__pycache__/attention.cpython-39.pyc  +0 −0
cogvideox-based/sat/sgm/modules/__pycache__/cp_enc_dec.cpython-39.pyc  +0 −0
cogvideox-based/sat/sgm/modules/__pycache__/ema.cpython-39.pyc  +0 −0
cogvideox-based/sat/sgm/modules/__pycache__/fuse_sft_block.cpython-39.pyc  +0 −0
cogvideox-based/sat/diffusion_video.py  (new file, mode 100644)

import math
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple, Union, Optional

from omegaconf import ListConfig, OmegaConf
from copy import deepcopy
import torch.nn.functional as F

from sat.helpers import print_rank0
import torch
from torch import nn

from sgm.modules import UNCONDITIONAL_CONFIG
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from sgm.util import (
    default,
    disabled_train,
    get_obj_from_str,
    instantiate_from_config,
    log_txt_as_img,
)
import gc
from sat import mpu
import random


class SATVideoDiffusionEngine(nn.Module):
    def __init__(self, args, **kwargs):
        super().__init__()

        model_config = args.model_config
        # model args preprocess
        log_keys = model_config.get("log_keys", None)
        input_key = model_config.get("input_key", "mp4")
        network_config = model_config.get("network_config", None)
        network_wrapper = model_config.get("network_wrapper", None)
        denoiser_config = model_config.get("denoiser_config", None)
        sampler_config = model_config.get("sampler_config", None)
        conditioner_config = model_config.get("conditioner_config", None)
        first_stage_config = model_config.get("first_stage_config", None)
        loss_fn_config = model_config.get("loss_fn_config", None)
        scale_factor = model_config.get("scale_factor", 1.0)
        latent_input = model_config.get("latent_input", False)
        disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
        no_cond_log = model_config.get("disable_first_stage_autocast", False)
        not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
        compile_model = model_config.get("compile_model", False)
        en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
        lr_scale = model_config.get("lr_scale", None)
        lora_train = model_config.get("lora_train", False)
        self.use_pd = model_config.get("use_pd", False)  # progressive distillation

        self.log_keys = log_keys
        self.input_key = input_key
        self.not_trainable_prefixes = not_trainable_prefixes
        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
        self.lr_scale = lr_scale
        self.lora_train = lora_train
        self.noised_image_input = model_config.get("noised_image_input", False)
        self.noised_image_all_concat = model_config.get("noised_image_all_concat", False)
        self.noised_image_dropout = model_config.get("noised_image_dropout", 0.0)

        if args.fp16:
            dtype = torch.float16
            dtype_str = "fp16"
        elif args.bf16:
            dtype = torch.bfloat16
            dtype_str = "bf16"
        else:
            dtype = torch.float32
            dtype_str = "fp32"
        self.dtype = dtype
        self.dtype_str = dtype_str

        network_config["params"]["dtype"] = dtype_str
        model = instantiate_from_config(network_config)
        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
            model, compile_model=compile_model, dtype=dtype
        )

        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
        self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))

        self._init_first_stage(first_stage_config)

        self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None

        self.latent_input = latent_input
        self.scale_factor = scale_factor
        self.disable_first_stage_autocast = disable_first_stage_autocast
        self.no_cond_log = no_cond_log
        self.device = args.device

    def disable_untrainable_params(self):
        total_trainable = 0
        for n, p in self.named_parameters():
            if p.requires_grad == False:
                continue
            flag = False
            for prefix in self.not_trainable_prefixes:
                if n.startswith(prefix) or prefix == "all":
                    flag = True
                    break

            lora_prefix = ["matrix_A", "matrix_B", 'final_layer', 'proj_sr', 'local']
            for prefix in lora_prefix:
                if prefix in n:
                    flag = False
                    break

            if flag:
                p.requires_grad_(False)
            else:
                print(n)
                total_trainable += p.numel()

        print_rank0("***** Total trainable parameters: " + str(total_trainable / 1000000) + "M *****")

    def reinit(self, parent_model=None):
        # reload the initial params from previous trained modules
        # you can also get access to other mixins through parent_model.get_mixin().
        pass

    def _init_first_stage(self, config):
        model = instantiate_from_config(config).eval()
        model.train = disabled_train
        for param in model.parameters():
            param.requires_grad = False
        self.first_stage_model = model

    def forward(self, x, hq_video, batch):
        loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch, hq_video, self.decode_first_stage)
        loss_mean = loss.mean()
        loss_dict = {"loss": loss_mean}
        return loss_mean, loss_dict

    def shared_step(self, batch: Dict) -> Any:
        x = self.get_input(batch)
        if self.lr_scale is not None:
            lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
            lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
            lr_z = self.encode_first_stage(lr_x, batch)
            batch["lr_input"] = lr_z

        x = x.permute(0, 2, 1, 3, 4).contiguous()  # (B, T, C, H, W) -> (B, C, T, H, W)
        hq_video = x  # (B, C, T, H, W)
        x = self.encode_first_stage(x, batch)
        x = x.permute(0, 2, 1, 3, 4).contiguous()  # (B, C, T, H, W) -> (B, T, C, H, W)

        if 'lq' in batch.keys():
            # print('LQ is NOT None')
            lq = batch['lq'].to(self.dtype)
            lq = lq.permute(0, 2, 1, 3, 4).contiguous()
            lq = self.encode_first_stage(lq, batch)
            lq = lq.permute(0, 2, 1, 3, 4).contiguous()
            batch['lq'] = lq
            # Uncomment for t2v training,
            # batch['lq'] = None

        gc.collect()
        torch.cuda.empty_cache()
        loss, loss_dict = self(x, hq_video, batch)
        return loss, loss_dict

    def get_input(self, batch):
        return batch[self.input_key].to(self.dtype)

    @torch.no_grad()
    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(self.first_stage_model.decoder, VideoDecoder):
                    kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
                else:
                    kwargs = {}
                use_cp = False
                out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out

    @torch.no_grad()
    def encode_first_stage(self, x, batch=None):
        frame = x.shape[2]

        if frame > 1 and self.latent_input:
            x = x.permute(0, 2, 1, 3, 4).contiguous()
            return x * self.scale_factor  # already encoded

        use_cp = False
        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
        n_rounds = math.ceil(x.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples])
                all_out.append(out)
        z = torch.cat(all_out, dim=0)
        z = self.scale_factor * z
        return z

    @torch.no_grad()
    def sample(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        shape: Union[None, Tuple, List] = None,
        prefix=None,
        concat_images=None,
        **kwargs,
    ):
        randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
        if hasattr(self, "seeded_noise"):
            randn = self.seeded_noise(randn)

        if prefix is not None:
            randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1)

        # broadcast noise
        mp_size = mpu.get_model_parallel_world_size()
        if mp_size > 1:
            global_rank = torch.distributed.get_rank() // mp_size
            src = global_rank * mp_size
            torch.distributed.broadcast(randn, src=src, group=mpu.get_model_parallel_group())

        scale = None
        scale_emb = None

        denoiser = lambda input, sigma, c, **addtional_model_inputs: self.denoiser(
            self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
        )

        samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
        samples = samples.to(self.dtype)
        return samples

    @torch.no_grad()
    def sample_sr(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        shape: Union[None, Tuple, List] = None,
        lq=None,
        prefix=None,
        concat_images=None,
        **kwargs,
    ):
        randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
        if hasattr(self, "seeded_noise"):
            randn = self.seeded_noise(randn)

        if prefix is not None:
            randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1)

        # broadcast noise
        mp_size = mpu.get_model_parallel_world_size()
        if mp_size > 1:
            global_rank = torch.distributed.get_rank() // mp_size
            src = global_rank * mp_size
            torch.distributed.broadcast(randn, src=src, group=mpu.get_model_parallel_group())

        scale = None
        scale_emb = None

        denoiser = lambda input, sigma, c, **addtional_model_inputs: self.denoiser(
            self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
        )

        # add lq condition (new)
        lq = lq.to(randn.device, self.dtype)
        lq = lq.permute(0, 2, 1, 3, 4).contiguous()
        lq = self.encode_first_stage(lq)
        lq = lq.permute(0, 2, 1, 3, 4).contiguous()
        lq = torch.cat((lq, lq), dim=0)  # for CFG inference
        # For T2V
        # lq = None
        # print('randn shape:', randn.shape)  # torch.Size([1, 8, 16, 60, 90])
        # print('lq shape:', lq.shape)  # torch.Size([1, 8, 16, 60, 90])

        samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, lq=lq)
        samples = samples.to(self.dtype)
        return samples

    @torch.no_grad()
    def log_conditionings(self, batch: Dict, n: int) -> Dict:
        """
        Defines heuristics to log different conditionings.
        These can be lists of strings (text-to-image), tensors, ints, ...
        """
        image_h, image_w = batch[self.input_key].shape[3:]
        log = dict()

        for embedder in self.conditioner.embedders:
            if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log:
                x = batch[embedder.input_key][:n]
                if isinstance(x, torch.Tensor):
                    if x.dim() == 1:
                        # class-conditional, convert integer to string
                        x = [str(x[i].item()) for i in range(x.shape[0])]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
                    elif x.dim() == 2:
                        # size and crop cond and the like
                        x = ["x".join([str(xx) for xx in x[i].tolist()]) for i in range(x.shape[0])]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                elif isinstance(x, (List, ListConfig)):
                    if isinstance(x[0], str):
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                else:
                    raise NotImplementedError()
                log[embedder.input_key] = xc
        return log

    @torch.no_grad()
    def log_video(
        self,
        batch: Dict,
        N: int = 8,
        ucg_keys: List[str] = None,
        only_log_video_latents=False,
        **kwargs,
    ) -> Dict:
        conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
        if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
                "Each defined ucg key for sampling must be in the provided conditioner input keys,"
                f"but we have {ucg_keys} vs. {conditioner_input_keys}"
            )
        else:
            ucg_keys = conditioner_input_keys
        log = dict()

        x = self.get_input(batch)

        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [],
        )

        sampling_kwargs = {}

        N = min(x.shape[0], N)
        x = x.to(self.device)[:N]
        if not self.latent_input:
            log["inputs"] = x.to(torch.float32)
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        z = self.encode_first_stage(x, batch)
        if not only_log_video_latents:
            log["reconstructions"] = self.decode_first_stage(z).to(torch.float32)
            log["reconstructions"] = log["reconstructions"].permute(0, 2, 1, 3, 4).contiguous()
        z = z.permute(0, 2, 1, 3, 4).contiguous()

        log.update(self.log_conditionings(batch, N))

        for k in c:
            if isinstance(c[k], torch.Tensor):
                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))

        samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs)  # b t c h w
        samples = samples.permute(0, 2, 1, 3, 4).contiguous()
        if only_log_video_latents:
            latents = 1.0 / self.scale_factor * samples
            log["latents"] = latents
        else:
            samples = self.decode_first_stage(samples).to(torch.float32)
            samples = samples.permute(0, 2, 1, 3, 4).contiguous()
            log["samples"] = samples
        return log
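Editor's note: `shared_step`, `sample_sr`, and `log_video` above repeatedly flip between the frame-major layout (B, T, C, H, W) used by the dataloader and transformer and the channel-major layout (B, C, T, H, W) expected by the first-stage VAE. The snippet below is a minimal, self-contained sketch of that round trip with made-up tensor sizes; it is illustrative only and not part of the committed file.

import torch

# Dummy batch in the dataloader layout used by shared_step: (B, T, C, H, W).
# The sizes here are arbitrary illustration values, not the model's real config.
video = torch.randn(2, 8, 3, 64, 64)

# VAE-facing layout (B, C, T, H, W), as done right before encode_first_stage(...)
vae_in = video.permute(0, 2, 1, 3, 4).contiguous()
assert vae_in.shape == (2, 3, 8, 64, 64)

# Back to the transformer-facing layout (B, T, C, H, W) after encoding
back = vae_in.permute(0, 2, 1, 3, 4).contiguous()
assert torch.equal(back, video)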
cogvideox-based/sat/dit_video_concat.py  (new file, mode 100644)

from functools import partial
from einops import rearrange, repeat

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from sat.model.base_model import BaseModel, non_conflict
from sat.model.mixins import BaseMixin
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
from sat.mpu.layers import ColumnParallelLinear
from sgm.util import instantiate_from_config

from sgm.modules.diffusionmodules.openaimodel import Timestep
from sgm.modules.diffusionmodules.util import (
    linear,
    timestep_embedding,
)
from sat.ops.layernorm import LayerNorm, RMSNorm


class ImagePatchEmbeddingMixin(BaseMixin):
    def __init__(
        self,
        in_channels,
        hidden_size,
        patch_size,
        bias=True,
        text_hidden_size=None,
    ):
        super().__init__()
        # print(in_channels)
        # self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.proj_sr = nn.Conv2d(in_channels * 2, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
        # Copy the weights of the first 16 channels from the original layer
        # self.proj_sr.weight.data[:, :in_channels, :, :] = self.proj.weight.data.clone()
        # # Initialize the weights of the last 16 channels to zero
        # torch.nn.init.constant_(self.proj_sr.weight.data[:, in_channels:, :, :], 0)
        # # If bias is used, directly copy the original bias values
        # if bias:
        #     self.proj_sr.bias.data = self.proj.bias.data.clone()
        if text_hidden_size is not None:
            self.text_proj = nn.Linear(text_hidden_size, hidden_size)
        else:
            self.text_proj = None

    def word_embedding_forward(self, input_ids, **kwargs):
        # now is 3d patch
        images = kwargs["images"]  # (b,t,c,h,w)
        B, T = images.shape[:2]
        emb = images.view(-1, *images.shape[2:])
        # --------
        # Debug
        # --------
        # emb_ori = emb
        # x_ori, _ = emb.chunk(2, dim=1)
        # emb = self.proj(x_ori)
        # emb_debug = self.proj_sr(emb_ori)  # ((b t),d,h/2,w/2) [2 * 8, 16, 60, 90]
        # print(torch.sqrt((emb - emb_debug)**2).mean())
        emb = self.proj_sr(emb)  # ((b t),d,h/2,w/2) [2 * 8, 32, 60, 90]

        emb = emb.view(B, T, *emb.shape[1:])
        emb = emb.flatten(3).transpose(2, 3)  # (b,t,n,d)
        emb = rearrange(emb, "b t n d -> b (t n) d")

        if self.text_proj is not None:
            text_emb = self.text_proj(kwargs["encoder_outputs"])
            emb = torch.cat((text_emb, emb), dim=1)  # (b,n_t+t*n_i,d)

        emb = emb.contiguous()
        return emb  # (b,n_t+t*n_i,d)

    def reinit(self, parent_model=None):
        w = self.proj_sr.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.proj_sr.bias, 0)
        del self.transformer.word_embeddings


def get_3d_sincos_pos_embed(
    embed_dim,
    grid_height,
    grid_width,
    t_size,
    cls_token=False,
    height_interpolation=1.0,
    width_interpolation=1.0,
    time_interpolation=1.0,
):
    """
    grid_size: int of the grid height and width
    t_size: int of the temporal size
    return:
    pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    assert embed_dim % 4 == 0
    embed_dim_spatial = embed_dim // 4 * 3
    embed_dim_temporal = embed_dim // 4

    # spatial
    grid_h = np.arange(grid_height, dtype=np.float32) / height_interpolation
    grid_w = np.arange(grid_width, dtype=np.float32) / width_interpolation
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_height, grid_width])
    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)

    # temporal
    grid_t = np.arange(t_size, dtype=np.float32) / time_interpolation
    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)

    # concate: [T, H, W] order
    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
    pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1)  # [T, H*W, D // 4]
    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
    pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0)  # [T, H*W, D // 4 * 3]

    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
    # pos_embed = pos_embed.reshape([-1, embed_dim])  # [T*H*W, D]

    return pos_embed  # [T, H*W, D]


def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_height, dtype=np.float32)
    grid_w = np.arange(grid_width, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_height, grid_width])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class Basic3DPositionEmbeddingMixin(BaseMixin):
    def __init__(
        self,
        height,
        width,
        compressed_num_frames,
        hidden_size,
        text_length=0,
        height_interpolation=1.0,
        width_interpolation=1.0,
        time_interpolation=1.0,
    ):
        super().__init__()
        self.height = height
        self.width = width
        self.text_length = text_length
        self.compressed_num_frames = compressed_num_frames
        self.spatial_length = height * width
        self.num_patches = height * width * compressed_num_frames
        self.pos_embedding = nn.Parameter(
            torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False
        )
        self.height_interpolation = height_interpolation
        self.width_interpolation = width_interpolation
        self.time_interpolation = time_interpolation

    def position_embedding_forward(self, position_ids, **kwargs):
        if kwargs["images"].shape[1] == 1:
            return self.pos_embedding[:, : self.text_length + self.spatial_length]

        return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]

    def reinit(self, parent_model=None):
        del self.transformer.position_embeddings
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embedding.shape[-1],
            self.height,
            self.width,
            self.compressed_num_frames,
            height_interpolation=self.height_interpolation,
            width_interpolation=self.width_interpolation,
            time_interpolation=self.time_interpolation,
        )
        pos_embed = torch.from_numpy(pos_embed).float()
        pos_embed = rearrange(pos_embed, "t n d -> (t n) d")
        self.pos_embedding.data[:, -self.num_patches :].copy_(pos_embed)


def broadcat(tensors, dim=-1):
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
    shape_len = list(shape_lens)[0]
    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all(
        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
    ), "invalid dimensions for broadcastable concatentation"
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


class Rotary3DPositionEmbeddingMixin(BaseMixin):
    def __init__(
        self,
        height,
        width,
        compressed_num_frames,
        hidden_size,
        hidden_size_head,
        text_length,
        theta=10000,
        rot_v=False,
        learnable_pos_embed=False,
    ):
        super().__init__()
        self.rot_v = rot_v

        dim_t = hidden_size_head // 4
        dim_h = hidden_size_head // 8 * 3
        dim_w = hidden_size_head // 8 * 3

        freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
        freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
        freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))

        grid_t = torch.arange(compressed_num_frames, dtype=torch.float32)
        grid_h = torch.arange(height, dtype=torch.float32)
        grid_w = torch.arange(width, dtype=torch.float32)

        freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
        freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
        freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)

        freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
        freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
        freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)

        freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
        freqs = rearrange(freqs, "t h w d -> (t h w) d")

        freqs = freqs.contiguous()
        freqs_sin = freqs.sin()
        freqs_cos = freqs.cos()
        self.register_buffer("freqs_sin", freqs_sin)
        self.register_buffer("freqs_cos", freqs_cos)

        self.text_length = text_length
        if learnable_pos_embed:
            num_patches = height * width * compressed_num_frames + text_length
            self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
        else:
            self.pos_embedding = None

    def rotary(self, t, **kwargs):
        seq_len = t.shape[2]
        freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
        freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)

        return t * freqs_cos + rotate_half(t) * freqs_sin

    def position_embedding_forward(self, position_ids, **kwargs):
        if self.pos_embedding is not None:
            return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
        else:
            return None

    def attention_fn(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        attention_dropout=None,
        log_attention_weights=None,
        scaling_attention_score=True,
        **kwargs,
    ):
        attention_fn_default = HOOKS_DEFAULT["attention_fn"]

        query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
        key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
        if self.rot_v:
            value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])

        return attention_fn_default(
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            attention_dropout=attention_dropout,
            log_attention_weights=log_attention_weights,
            scaling_attention_score=scaling_attention_score,
            **kwargs,
        )


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
    """
    x: (N, T/2 * S, patch_size**3 * C)
    imgs: (N, T, H, W, C)
    """
    if rope_position_ids is not None:
        assert NotImplementedError
        # do pix2struct unpatchify
        L = x.shape[1]
        x = x.reshape(shape=(x.shape[0], L, p, p, c))
        x = torch.einsum("nlpqc->ncplq", x)
        imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
    else:
        b = x.shape[0]
        imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p)

    return imgs


class FinalLayerMixin(BaseMixin):
    def __init__(
        self,
        hidden_size,
        time_embed_dim,
        patch_size,
        out_channels,
        latent_width,
        latent_height,
        elementwise_affine,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))

        self.spatial_length = latent_width * latent_height // patch_size**2
        self.latent_width = latent_width
        self.latent_height = latent_height

    def final_forward(self, logits, **kwargs):
        x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"]  # x:(b,(t n),d)

        shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)

        return unpatchify(
            x,
            c=self.out_channels,
            p=self.patch_size,
            w=self.latent_width // self.patch_size,
            h=self.latent_height // self.patch_size,
            rope_position_ids=kwargs.get("rope_position_ids", None),
            **kwargs,
        )

    def reinit(self, parent_model=None):
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0)


class SwiGLUMixin(BaseMixin):
    def __init__(self, num_layers, in_features, hidden_features, bias=False):
        super().__init__()
        self.w2 = nn.ModuleList(
            [
                ColumnParallelLinear(
                    in_features,
                    hidden_features,
                    gather_output=False,
                    bias=bias,
                    module=self,
                    name="dense_h_to_4h_gate",
                )
                for i in range(num_layers)
            ]
        )

    def mlp_forward(self, hidden_states, **kw_args):
        x = hidden_states
        origin = self.transformer.layers[kw_args["layer_id"]].mlp
        x1 = origin.dense_h_to_4h(x)
        x2 = self.w2[kw_args["layer_id"]](x)
        hidden = origin.activation_func(x2) * x1
        x = origin.dense_4h_to_h(hidden)
        return x


class AdaLNMixin(BaseMixin):
    def __init__(
        self,
        width,
        height,
        hidden_size,
        num_layers,
        time_embed_dim,
        compressed_num_frames,
        qk_ln=True,
        hidden_size_head=None,
        elementwise_affine=True,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.width = width
        self.height = height
        self.compressed_num_frames = compressed_num_frames

        self.adaLN_modulations = nn.ModuleList(
            [nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
        )

        self.qk_ln = qk_ln
        if qk_ln:
            self.query_layernorm_list = nn.ModuleList(
                [
                    LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
                    for _ in range(num_layers)
                ]
            )
            self.key_layernorm_list = nn.ModuleList(
                [
                    LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
                    for _ in range(num_layers)
                ]
            )

    def layer_forward(
        self,
        hidden_states,
        mask,
        *args,
        **kwargs,
    ):
        text_length = kwargs["text_length"]
        # hidden_states (b,(n_t+t*n_i),d)
        text_hidden_states = hidden_states[:, :text_length]  # (b,n,d)
        img_hidden_states = hidden_states[:, text_length:]  # (b,(t n),d)

        layer = self.transformer.layers[kwargs["layer_id"]]
        adaLN_modulation = self.adaLN_modulations[kwargs["layer_id"]]

        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            text_shift_msa,
            text_scale_msa,
            text_gate_msa,
            text_shift_mlp,
            text_scale_mlp,
            text_gate_mlp,
        ) = adaLN_modulation(kwargs["emb"]).chunk(12, dim=1)
        gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = (
            gate_msa.unsqueeze(1),
            gate_mlp.unsqueeze(1),
            text_gate_msa.unsqueeze(1),
            text_gate_mlp.unsqueeze(1),
        )

        # self full attention (b,(t n),d) b: batchsize; (t n): temp & spa; d: hidden_size
        img_attention_input = layer.input_layernorm(img_hidden_states)
        text_attention_input = layer.input_layernorm(text_hidden_states)
        img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
        text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)

        # Spatial LIEM
        _, thw, _ = img_attention_input.shape
        t = thw // (self.height * self.width)
        spa_fea = rearrange(img_attention_input, 'b (t h w) c -> (b t) c h w', h=self.height, w=self.width)
        spa_fea = layer.spa_local(spa_fea)
        # Temporal LIEM
        temp_fea = rearrange(spa_fea, '(b t) c h w -> (b h w) t c', t=t)
        temp_fea = layer.temp_local(temp_fea)
        img_attention_input = rearrange(temp_fea, '(b h w) t c -> b (t h w) c', h=self.height, w=self.width)

        attention_input = torch.cat((text_attention_input, img_attention_input), dim=1)  # (b,n_t+t*n_i,d)
        attention_output = layer.attention(attention_input, mask, **kwargs)
        text_attention_output = attention_output[:, :text_length]  # (b,n,d)
        img_attention_output = attention_output[:, text_length:]  # (b,(t n),d)

        if self.transformer.layernorm_order == "sandwich":
            text_attention_output = layer.third_layernorm(text_attention_output)
            img_attention_output = layer.third_layernorm(img_attention_output)
        img_hidden_states = img_hidden_states + gate_msa * img_attention_output  # (b,(t n),d)
        text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output  # (b,n,d)

        # mlp (b,(t n),d)
        img_mlp_input = layer.post_attention_layernorm(img_hidden_states)  # vision (b,(t n),d)
        text_mlp_input = layer.post_attention_layernorm(text_hidden_states)  # language (b,n,d)
        img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp)
        text_mlp_input = modulate(text_mlp_input, text_shift_mlp, text_scale_mlp)
        mlp_input = torch.cat((text_mlp_input, img_mlp_input), dim=1)  # (b,(n_t+t*n_i),d
        mlp_output = layer.mlp(mlp_input, **kwargs)
        img_mlp_output = mlp_output[:, text_length:]  # vision (b,(t n),d)
        text_mlp_output = mlp_output[:, :text_length]  # language (b,n,d)
        if self.transformer.layernorm_order == "sandwich":
            text_mlp_output = layer.fourth_layernorm(text_mlp_output)
            img_mlp_output = layer.fourth_layernorm(img_mlp_output)

        img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output  # vision (b,(t n),d)
        text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output  # language (b,n,d)

        hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1)  # (b,(n_t+t*n_i),d)
        return hidden_states

    def reinit(self, parent_model=None):
        for layer in self.adaLN_modulations:
            nn.init.constant_(layer[-1].weight, 0)
            nn.init.constant_(layer[-1].bias, 0)

    @non_conflict
    def attention_fn(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        attention_dropout=None,
        log_attention_weights=None,
        scaling_attention_score=True,
        old_impl=attention_fn_default,
        **kwargs,
    ):
        if self.qk_ln:
            query_layernorm = self.query_layernorm_list[kwargs["layer_id"]]
            key_layernorm = self.key_layernorm_list[kwargs["layer_id"]]
            query_layer = query_layernorm(query_layer)
            key_layer = key_layernorm(key_layer)

        return old_impl(
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            attention_dropout=attention_dropout,
            log_attention_weights=log_attention_weights,
            scaling_attention_score=scaling_attention_score,
            **kwargs,
        )


str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}


class DiffusionTransformer(BaseModel):
    def __init__(
        self,
        transformer_args,
        num_frames,
        time_compressed_rate,
        latent_width,
        latent_height,
        patch_size,
        in_channels,
        out_channels,
        hidden_size,
        num_layers,
        num_attention_heads,
        elementwise_affine,
        time_embed_dim=None,
        num_classes=None,
        modules={},
        input_time="adaln",
        adm_in_channels=None,
        parallel_output=True,
        height_interpolation=1.0,
        width_interpolation=1.0,
        time_interpolation=1.0,
        use_SwiGLU=False,
        use_RMSNorm=False,
        zero_init_y_embed=False,
        **kwargs,
    ):
        self.latent_width = latent_width
        self.latent_height = latent_height
        self.patch_size = patch_size
        self.num_frames = num_frames
        self.time_compressed_rate = time_compressed_rate
        self.spatial_length = latent_width * latent_height // patch_size**2
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_size = hidden_size
        self.model_channels = hidden_size
        self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
        self.num_classes = num_classes
        self.adm_in_channels = adm_in_channels
        self.input_time = input_time
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.is_decoder = transformer_args.is_decoder
        self.elementwise_affine = elementwise_affine
        self.height_interpolation = height_interpolation
        self.width_interpolation = width_interpolation
        self.time_interpolation = time_interpolation
        self.inner_hidden_size = hidden_size * 4
        self.zero_init_y_embed = zero_init_y_embed
        try:
            self.dtype = str_to_dtype[kwargs.pop("dtype")]
        except:
            self.dtype = torch.float32

        if use_SwiGLU:
            kwargs["activation_func"] = F.silu
        elif "activation_func" not in kwargs:
            approx_gelu = nn.GELU(approximate="tanh")
            kwargs["activation_func"] = approx_gelu

        if use_RMSNorm:
            kwargs["layernorm"] = RMSNorm
        else:
            kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6)

        transformer_args.num_layers = num_layers
        transformer_args.hidden_size = hidden_size
        transformer_args.num_attention_heads = num_attention_heads
        transformer_args.parallel_output = parallel_output
        super().__init__(args=transformer_args, transformer=None, **kwargs)

        module_configs = modules
        self._build_modules(module_configs)

        if use_SwiGLU:
            self.add_mixin(
                "swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True
            )

    def _build_modules(self, module_configs):
        model_channels = self.hidden_size
        # time_embed_dim = model_channels * 4
        time_embed_dim = self.time_embed_dim
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "timestep":
                self.label_emb = nn.Sequential(
                    Timestep(model_channels),
                    nn.Sequential(
                        linear(model_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    ),
                )
            elif self.num_classes == "sequential":
                assert self.adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(self.adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
                if self.zero_init_y_embed:
                    nn.init.constant_(self.label_emb[0][2].weight, 0)
                    nn.init.constant_(self.label_emb[0][2].bias, 0)
            else:
                raise ValueError()

        pos_embed_config = module_configs["pos_embed_config"]
        self.add_mixin(
            "pos_embed",
            instantiate_from_config(
                pos_embed_config,
                height=self.latent_height // self.patch_size,
                width=self.latent_width // self.patch_size,
                compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
                hidden_size=self.hidden_size,
            ),
            reinit=True,
        )

        patch_embed_config = module_configs["patch_embed_config"]
        self.add_mixin(
            "patch_embed",
            instantiate_from_config(
                patch_embed_config,
                patch_size=self.patch_size,
                hidden_size=self.hidden_size,
                in_channels=self.in_channels,
            ),
            reinit=True,
        )
        if self.input_time == "adaln":
            adaln_layer_config = module_configs["adaln_layer_config"]
            self.add_mixin(
                "adaln_layer",
                instantiate_from_config(
                    adaln_layer_config,
                    height=self.latent_height // self.patch_size,
                    width=self.latent_width // self.patch_size,
                    hidden_size=self.hidden_size,
                    num_layers=self.num_layers,
                    compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
                    hidden_size_head=self.hidden_size // self.num_attention_heads,
                    time_embed_dim=self.time_embed_dim,
                    elementwise_affine=self.elementwise_affine,
                ),
            )
        else:
            raise NotImplementedError

        final_layer_config = module_configs["final_layer_config"]
        self.add_mixin(
            "final_layer",
            instantiate_from_config(
                final_layer_config,
                hidden_size=self.hidden_size,
                patch_size=self.patch_size,
                out_channels=self.out_channels,
                time_embed_dim=self.time_embed_dim,
                latent_width=self.latent_width,
                latent_height=self.latent_height,
                elementwise_affine=self.elementwise_affine,
            ),
            reinit=True,
        )

        if "lora_config" in module_configs:
            lora_config = module_configs["lora_config"]
            self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)

        return

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        # print('x shape:', x.shape)  # train phase: torch.Size([2, 8, 32, 60, 90])
        b, t, d, h, w = x.shape
        if x.dtype != self.dtype:
            x = x.to(self.dtype)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            # assert y.shape[0] == x.shape[0]
            assert x.shape[0] % y.shape[0] == 0
            y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
            emb = emb + self.label_emb(y)

        kwargs["seq_length"] = t * h * w // (self.patch_size**2)
        kwargs["images"] = x
        kwargs["emb"] = emb
        kwargs["encoder_outputs"] = context
        kwargs["text_length"] = context.shape[1]

        kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
        output = super().forward(**kwargs)[0]
        return output
\ No newline at end of file
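Editor's note: the rotary application above computes `t * freqs_cos + rotate_half(t) * freqs_sin`, which rotates each adjacent channel pair of a token by its position-dependent angle. The snippet below is a small standalone sketch of that trick; the toy `cos`/`sin` tensors stand in for the precomputed `freqs_cos`/`freqs_sin` buffers of Rotary3DPositionEmbeddingMixin and are not the real frequencies.

import torch
from einops import rearrange

def rotate_half(x):
    # Same pairing trick as in dit_video_concat.py: split channels into
    # (d, 2) pairs, then map each pair (a, b) to (-b, a).
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    return rearrange(torch.stack((-x2, x1), dim=-1), "... d r -> ... (d r)")

# Toy per-position angles, repeated across the channel dimension so that both
# members of each pair share the same angle (as the repeat(..., "(n r)") does).
seq, dim = 4, 8
angles = torch.linspace(0, 1, seq).unsqueeze(-1).repeat(1, dim)
cos, sin = angles.cos(), angles.sin()

t = torch.randn(seq, dim)
rotated = t * cos + rotate_half(t) * sin

# Each pair undergoes a plain 2D rotation, so token norms are preserved.
assert torch.allclose(rotated.norm(dim=-1), t.norm(dim=-1), atol=1e-5)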
cogvideox-based/sat/inference_sr.sh  (new file, mode 100755)

#! /bin/bash

echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"

environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"

run_cmd="$environs python sample_sr.py --base configs/cogvideox_5b/cogvideox_5b_infer_sr.yaml"

echo ${run_cmd}
eval ${run_cmd}

echo "DONE on `hostname`"
\ No newline at end of file
cogvideox-based/sat/requirements.txt  (new file, mode 100644)
SwissArmyTransformer==0.4.12
omegaconf==2.3.0
torch==2.4.0
torchvision==0.19.0
pytorch_lightning==2.3.3
kornia==0.7.3
beartype==0.18.5
numpy==2.0.1
fsspec==2024.5.0
safetensors==0.4.3
imageio-ffmpeg==0.5.1
imageio==2.34.2
# scipy==1.14.0
decord==0.6.0
wandb==0.17.5
deepspeed==0.14.4
\ No newline at end of file
cogvideox-based/sat/sample_sr.py  (new file, mode 100644)

import os
import math
import argparse
from typing import List, Union
from tqdm import tqdm
from omegaconf import ListConfig
import imageio

import torch
from einops import rearrange
import numpy as np
from einops import rearrange
import torchvision.transforms as TT

from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from sat import mpu

from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
from data_video import PairedCaptionDataset
from color_fix import adain_color_fix


def read_from_cli():
    cnt = 0
    try:
        while True:
            x = input("Please input English text (Ctrl-D quit): ")
            yield x.strip(), cnt
            cnt += 1
    except EOFError as e:
        pass


def read_from_file(p, rank=0, world_size=1):
    with open(p, "r") as fin:
        cnt = -1
        for l in fin:
            cnt += 1
            if cnt % world_size != rank:
                continue
            yield l.strip(), cnt


def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))


def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc


def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
    os.makedirs(save_path, exist_ok=True)

    for i, vid in enumerate(video_batch):
        gif_frames = []
        for frame in vid:
            frame = rearrange(frame, "c h w -> h w c")
            frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
            gif_frames.append(frame)
        now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
        with imageio.get_writer(now_save_path, fps=fps, quality=10) as writer:
            for frame in gif_frames:
                writer.append_data(frame)


def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        arr = resize(
            arr,
            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        arr = resize(
            arr,
            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
            interpolation=InterpolationMode.BICUBIC,
        )

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - image_size[0]
    delta_w = w - image_size[1]

    if reshape_mode == "random" or reshape_mode == "none":
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
    return arr


def sampling_main(args, model_cls):
    test_dataset = PairedCaptionDataset(
        data_dir='/mnt/bn/videodataset/VSR/dataset/VSRTest/cogvideox_test',
        null_text_ratio=0,
        num_frames=25,
    )
    test_dataloader = torch.utils.data.DataLoader(test_dataset, num_workers=8, batch_size=1, shuffle=False)

    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls

    load_checkpoint(model, args)
    model.eval()

    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        print("rank and world_size", rank, world_size)
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

    image_size = [480, 720]

    sample_func = model.sample_sr
    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device

    with torch.no_grad():
        for step, batch in enumerate(test_dataloader):
            cnt = step
            gt = batch['mp4']
            text = batch['txt']
            lq = batch['lq']
            fps = batch['fps']

            # reload model on GPU
            model.to(device)
            print("rank:", rank, "start to process", text, cnt)

            # TODO: broadcast image2video
            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
            )
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(l) for l in batch[key]])
                else:
                    print(key, batch[key])
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in c:
                if not k == "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

            for index in range(args.batch_size):
                # reload model on GPU
                model.to(device)
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                    lq=lq,
                )
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
                # print('max samples_z:', torch.max(samples_z))  # 3.0996
                # print('min samples_z:', torch.min(samples_z))  # -3.0742

                # Unload the model from GPU to save GPU memory
                model.to("cpu")
                torch.cuda.empty_cache()
                first_stage_model = model.first_stage_model
                first_stage_model = first_stage_model.to(device)

                latent = 1.0 / model.scale_factor * samples_z

                # Decode latent serial to save GPU memory
                print('latent shape:', latent.shape)
                recons = []
                loop_num = (T - 1) // 2
                for i in range(loop_num):
                    if i == 0:
                        start_frame, end_frame = 0, 3
                    else:
                        start_frame, end_frame = i * 2 + 1, i * 2 + 3
                    if i == loop_num - 1:
                        clear_fake_cp_cache = True
                    else:
                        clear_fake_cp_cache = False
                    with torch.no_grad():
                        recon = first_stage_model.decode(
                            latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                        )

                    recons.append(recon)

                recon = torch.cat(recons, dim=2).to(torch.float32)
                samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

                # Using color fix
                samples = adain_color_fix(samples, gt)

                # samples,lq: (b, t, c, h, w)
                save_path = os.path.join(
                    args.output_dir, str(cnt) + "_" + text[0].replace(" ", "_").replace("/", "")[:120]
                )
                save_path_gt = os.path.join(
                    args.output_dir, str(cnt) + "_gt_" + text[0].replace(" ", "_").replace("/", "")[:120]
                )
                save_path_lq = os.path.join(
                    args.output_dir, str(cnt) + "_lq_" + text[0].replace(" ", "_").replace("/", "")[:120]
                )
                if mpu.get_model_parallel_rank() == 0:
                    save_video_as_grid_and_mp4(samples, save_path, fps=float(fps))
                    # save_video_as_grid_and_mp4(torch.clamp((gt + 1.0) / 2.0, min=0.0, max=1.0).cpu(), save_path_gt, fps=float(fps))
                    # save_video_as_grid_and_mp4(torch.clamp((lq + 1.0) / 2.0, min=0.0, max=1.0).cpu(), save_path_lq, fps=float(fps))


if __name__ == "__main__":
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()

    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    del args.deepspeed_config
    args.model_config.first_stage_config.params.cp_size = 1
    args.model_config.network_config.params.transformer_args.model_parallel_size = 1
    args.model_config.network_config.params.transformer_args.checkpoint_activations = False
    args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False

    sampling_main(args, model_cls=SATVideoDiffusionEngine)
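Editor's note: the serial VAE decode in `sampling_main` splits the T latent frames into small chunks so only a few frames sit on the GPU at once: the first call decodes latent frames [0, 3), each later call decodes two more, so (T − 1) // 2 calls cover all T frames. The sketch below just replays that index arithmetic for illustration; T = 13 is an arbitrary odd value, not a required setting.

# Sanity check of the chunk boundaries used by the serial decode loop above.
T = 13
loop_num = (T - 1) // 2
chunks = []
for i in range(loop_num):
    if i == 0:
        start_frame, end_frame = 0, 3
    else:
        start_frame, end_frame = i * 2 + 1, i * 2 + 3
    chunks.append((start_frame, end_frame))

print(chunks)  # [(0, 3), (3, 5), (5, 7), ...]
assert chunks[-1][1] == T                      # the last chunk ends exactly at T
assert sum(e - s for s, e in chunks) == T      # the chunks tile all T latent frames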
cogvideox-based/sat/sgm/__init__.py  (new file, mode 100644)

from .models import AutoencodingEngine
from .util import get_configs_path, instantiate_from_config

__version__ = "0.1.0"
cogvideox-based/sat/sgm/__pycache__/__init__.cpython-39.pyc  (new binary file, mode 100644): File added
cogvideox-based/sat/sgm/__pycache__/util.cpython-39.pyc  (new binary file, mode 100644): File added
cogvideox-based/sat/sgm/__pycache__/webds.cpython-39.pyc  (new binary file, mode 100644): File added
cogvideox-based/sat/sgm/lr_scheduler.py
0 → 100644
View file @
1f5da520
import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """

    def __init__(
        self,
        warm_up_steps,
        lr_min,
        lr_max,
        lr_start,
        max_decay_steps,
        verbosity_interval=0,
    ):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.0
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
            self.last_lr = lr
            return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """

    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.0
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f
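A minimal usage sketch (illustrative, not part of this file): these schedulers return a learning-rate multiplier, so they are normally wrapped in torch.optim.lr_scheduler.LambdaLR with the optimizer's base lr set to 1.0, as the docstrings note. The parameter values below are assumptions chosen only for the example.

import torch
from sgm.lr_scheduler import LambdaLinearScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)  # base_lr of 1.0, per the docstring

# one cycle: 1k warm-up steps, then linear decay over a 10k-step cycle (illustrative values)
scheduler_fn = LambdaLinearScheduler(
    warm_up_steps=[1000], f_min=[0.01], f_max=[1.0], f_start=[1e-6], cycle_lengths=[10000]
)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_fn)

for step in range(3):
    optimizer.step()
    lr_scheduler.step()  # effective lr = base_lr * scheduler_fn(step)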
cogvideox-based/sat/sgm/models/__init__.py
from .autoencoder import AutoencodingEngine
cogvideox-based/sat/sgm/models/__pycache__/__init__.cpython-39.pyc
File added
cogvideox-based/sat/sgm/models/__pycache__/autoencoder.cpython-39.pyc
File added
cogvideox-based/sat/sgm/models/autoencoder.py
import logging
import math
import re
import random
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed
import torch.nn as nn
from einops import rearrange
from packaging import version

from ..modules.autoencoding.regularizers import AbstractRegularizer
from ..modules.ema import LitEma
from ..util import (
    default,
    get_nested_attribute,
    get_obj_from_str,
    instantiate_from_config,
    initialize_context_parallel,
    get_context_parallel_group,
    get_context_parallel_group_rank,
    is_context_parallel_initialized,
)
from ..modules.cp_enc_dec import _conv_split, _conv_gather

logpy = logging.getLogger(__name__)
class AbstractAutoencoder(pl.LightningModule):
    """
    This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
    unCLIP models, etc. Hence, it is fairly general, and specific features
    (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = "jpg",
    ):
        super().__init__()

        self.input_key = input_key
        self.use_ema = ema_decay is not None
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            self.model_ema = LitEma(self, decay=ema_decay)
            logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if version.parse(torch.__version__) >= version.parse("2.0.0"):
            self.automatic_optimization = False

    def apply_ckpt(self, ckpt: Union[None, str, dict]):
        if ckpt is None:
            return
        if isinstance(ckpt, str):
            ckpt = {
                "target": "sgm.modules.checkpoint.CheckpointEngine",
                "params": {"ckpt_path": ckpt},
            }
        engine = instantiate_from_config(ckpt)
        engine(self)

    @abstractmethod
    def get_input(self, batch) -> Any:
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        # for EMA computation
        if self.use_ema:
            self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                logpy.info(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    logpy.info(f"{context}: Restored training weights")

    @abstractmethod
    def encode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("encode()-method of abstract base class called")

    @abstractmethod
    def decode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("decode()-method of abstract base class called")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
        return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))

    def configure_optimizers(self) -> Any:
        raise NotImplementedError()
class AutoencodingEngine(AbstractAutoencoder):
    """
    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
    (we also restore them explicitly as special cases for legacy reasons).
    Regularizations such as KL or VQ are moved to the regularizer class.
    """

    def __init__(
        self,
        *args,
        encoder_config: Dict,
        decoder_config: Dict,
        loss_config: Dict,
        regularizer_config: Dict,
        optimizer_config: Union[Dict, None] = None,
        lr_g_factor: float = 1.0,
        trainable_ae_params: Optional[List[List[str]]] = None,
        ae_optimizer_args: Optional[List[dict]] = None,
        trainable_disc_params: Optional[List[List[str]]] = None,
        disc_optimizer_args: Optional[List[dict]] = None,
        disc_start_iter: int = 0,
        diff_boost_factor: float = 3.0,
        ckpt_engine: Union[None, str, dict] = None,
        ckpt_path: Optional[str] = None,
        additional_decode_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.automatic_optimization = False  # pytorch lightning

        self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
        self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
        self.loss: torch.nn.Module = instantiate_from_config(loss_config)
        self.regularization: AbstractRegularizer = instantiate_from_config(regularizer_config)
        self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"})
        self.diff_boost_factor = diff_boost_factor
        self.disc_start_iter = disc_start_iter
        self.lr_g_factor = lr_g_factor
        self.trainable_ae_params = trainable_ae_params
        if self.trainable_ae_params is not None:
            self.ae_optimizer_args = default(
                ae_optimizer_args,
                [{} for _ in range(len(self.trainable_ae_params))],
            )
            assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
        else:
            self.ae_optimizer_args = [{}]  # makes type consistent

        self.trainable_disc_params = trainable_disc_params
        if self.trainable_disc_params is not None:
            self.disc_optimizer_args = default(
                disc_optimizer_args,
                [{} for _ in range(len(self.trainable_disc_params))],
            )
            assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
        else:
            self.disc_optimizer_args = [{}]  # makes type consistent

        if ckpt_path is not None:
            assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
            logpy.warn("Checkpoint path is deprecated, use `checkpoint_engine` instead")
        self.apply_ckpt(default(ckpt_path, ckpt_engine))
        self.additional_decode_keys = set(default(additional_decode_keys, []))

    def get_input(self, batch: Dict) -> torch.Tensor:
        # assuming unified data format, dataloader returns a dict.
        # image tensors should be scaled to -1 ... 1 and in channels-first
        # format (e.g., bchw instead of bhwc)
        return batch[self.input_key]

    def get_autoencoder_params(self) -> list:
        params = []
        if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
            params += list(self.loss.get_trainable_autoencoder_parameters())
        if hasattr(self.regularization, "get_trainable_parameters"):
            params += list(self.regularization.get_trainable_parameters())
        params = params + list(self.encoder.parameters())
        params = params + list(self.decoder.parameters())
        return params

    def get_discriminator_params(self) -> list:
        if hasattr(self.loss, "get_trainable_parameters"):
            params = list(self.loss.get_trainable_parameters())  # e.g., discriminator
        else:
            params = []
        return params

    def get_last_layer(self):
        return self.decoder.get_last_layer()

    def encode(
        self,
        x: torch.Tensor,
        return_reg_log: bool = False,
        unregularized: bool = False,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        z = self.encoder(x, **kwargs)
        if unregularized:
            return z, dict()
        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.decoder(z, **kwargs)
        return x

    def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z, **additional_decode_kwargs)
        return z, dec, reg_log
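An illustrative round trip through the methods above (a sketch only: `autoencoder` stands for an already-instantiated AutoencodingEngine built from encoder/decoder/loss/regularizer configs, which this snippet does not construct):

import torch

# `autoencoder` is assumed to be an AutoencodingEngine instance (hypothetical here).
x = torch.randn(2, 3, 256, 256)             # images scaled to [-1, 1], channels-first (bchw)
z, x_rec, reg_log = autoencoder(x)          # forward(): encode -> regularize -> decode
print(z.shape, x_rec.shape, list(reg_log))  # latent, reconstruction, regularizer statistics (e.g. KL)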
    def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
        x = self.get_input(batch)
        additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
        z, xrec, regularization_log = self(x, **additional_decode_kwargs)
        if hasattr(self.loss, "forward_keys"):
            extra_info = {
                "z": z,
                "optimizer_idx": optimizer_idx,
                "global_step": self.global_step,
                "last_layer": self.get_last_layer(),
                "split": "train",
                "regularization_log": regularization_log,
                "autoencoder": self,
            }
            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
        else:
            extra_info = dict()

        if optimizer_idx == 0:
            # autoencode
            out_loss = self.loss(x, xrec, **extra_info)
            if isinstance(out_loss, tuple):
                aeloss, log_dict_ae = out_loss
            else:
                # simple loss function
                aeloss = out_loss
                log_dict_ae = {"train/loss/rec": aeloss.detach()}

            self.log_dict(
                log_dict_ae,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=True,
                sync_dist=False,
            )
            self.log(
                "loss",
                aeloss.mean().detach(),
                prog_bar=True,
                logger=False,
                on_epoch=False,
                on_step=True,
            )
            return aeloss
        elif optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
            # -> discriminator always needs to return a tuple
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss
        else:
            raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")

    def training_step(self, batch: dict, batch_idx: int):
        opts = self.optimizers()
        if not isinstance(opts, list):
            # Non-adversarial case
            opts = [opts]
        optimizer_idx = batch_idx % len(opts)
        if self.global_step < self.disc_start_iter:
            optimizer_idx = 0
        opt = opts[optimizer_idx]
        opt.zero_grad()
        with opt.toggle_model():
            loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx)
            self.manual_backward(loss)
        opt.step()
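Because automatic optimization is disabled, training_step drives the optimizers by hand: it picks one optimizer per batch, alternating between the autoencoder (index 0) and, once training has reached disc_start_iter, the discriminator (index 1). A toy illustration of that selection logic (plain Python; batch_idx stands in for global_step and the values are illustrative):

disc_start_iter = 3
opts = ["opt_ae", "opt_disc"]
for batch_idx in range(6):
    optimizer_idx = batch_idx % len(opts)
    if batch_idx < disc_start_iter:  # stand-in for the global_step gate above
        optimizer_idx = 0            # train only the autoencoder at first
    print(batch_idx, opts[optimizer_idx])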
    def validation_step(self, batch: dict, batch_idx: int) -> Dict:
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
            log_dict.update(log_dict_ema)
        return log_dict

    def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
        x = self.get_input(batch)

        z, xrec, regularization_log = self(x)
        if hasattr(self.loss, "forward_keys"):
            extra_info = {
                "z": z,
                "optimizer_idx": 0,
                "global_step": self.global_step,
                "last_layer": self.get_last_layer(),
                "split": "val" + postfix,
                "regularization_log": regularization_log,
                "autoencoder": self,
            }
            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
        else:
            extra_info = dict()
        out_loss = self.loss(x, xrec, **extra_info)
        if isinstance(out_loss, tuple):
            aeloss, log_dict_ae = out_loss
        else:
            # simple loss function
            aeloss = out_loss
            log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
        full_log_dict = log_dict_ae

        if "optimizer_idx" in extra_info:
            extra_info["optimizer_idx"] = 1
            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
            full_log_dict.update(log_dict_disc)
        self.log(
            f"val{postfix}/loss/rec",
            log_dict_ae[f"val{postfix}/loss/rec"],
            sync_dist=True,
        )
        self.log_dict(full_log_dict, sync_dist=True)
        return full_log_dict

    def get_param_groups(
        self, parameter_names: List[List[str]], optimizer_args: List[dict]
    ) -> Tuple[List[Dict[str, Any]], int]:
        groups = []
        num_params = 0
        for names, args in zip(parameter_names, optimizer_args):
            params = []
            for pattern_ in names:
                pattern_params = []
                pattern = re.compile(pattern_)
                for p_name, param in self.named_parameters():
                    if re.match(pattern, p_name):
                        pattern_params.append(param)
                        num_params += param.numel()
                if len(pattern_params) == 0:
                    logpy.warn(f"Did not find parameters for pattern {pattern_}")
                params.extend(pattern_params)
            groups.append({"params": params, **args})
        return groups, num_params
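How trainable_ae_params and ae_optimizer_args feed into get_param_groups above: each inner list of regex patterns becomes one optimizer parameter group, with the matching entry of the args list merged into it. A hypothetical configuration (the pattern and kwargs below are illustrative, not taken from this commit):

# One group containing only decoder weights, trained without weight decay (illustrative).
trainable_ae_params = [[r"decoder\..*"]]
ae_optimizer_args = [{"weight_decay": 0.0}]
# Inside configure_optimizers this becomes roughly:
#   groups, num_params = self.get_param_groups(trainable_ae_params, ae_optimizer_args)
#   -> [{"params": [<decoder parameter tensors>], "weight_decay": 0.0}]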
    def configure_optimizers(self) -> List[torch.optim.Optimizer]:
        if self.trainable_ae_params is None:
            ae_params = self.get_autoencoder_params()
        else:
            ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
            logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
        if self.trainable_disc_params is None:
            disc_params = self.get_discriminator_params()
        else:
            disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
            logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
        opt_ae = self.instantiate_optimizer_from_config(
            ae_params,
            default(self.lr_g_factor, 1.0) * self.learning_rate,
            self.optimizer_config,
        )
        opts = [opt_ae]
        if len(disc_params) > 0:
            opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
            opts.append(opt_disc)

        return opts

    @torch.no_grad()
    def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
        log = dict()
        additional_decode_kwargs = {}
        x = self.get_input(batch)
        additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})

        _, xrec, _ = self(x, **additional_decode_kwargs)
        log["inputs"] = x
        log["reconstructions"] = xrec
        diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
        diff.clamp_(0, 1.0)
        log["diff"] = 2.0 * diff - 1.0
        # diff_boost shows location of small errors, by boosting their
        # brightness.
        log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
        if hasattr(self.loss, "log_images"):
            log.update(self.loss.log_images(x, xrec))
        with self.ema_scope():
            _, xrec_ema, _ = self(x, **additional_decode_kwargs)
            log["reconstructions_ema"] = xrec_ema
            diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
            diff_ema.clamp_(0, 1.0)
            log["diff_ema"] = 2.0 * diff_ema - 1.0
            log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
        if additional_log_kwargs:
            additional_decode_kwargs.update(additional_log_kwargs)
            _, xrec_add, _ = self(x, **additional_decode_kwargs)
            log_str = "reconstructions-" + "-".join(
                [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
            )
            log[log_str] = xrec_add
        return log
class AutoencodingEngineLegacy(AutoencodingEngine):
    def __init__(self, embed_dim: int, **kwargs):
        self.max_batch_size = kwargs.pop("max_batch_size", None)
        ddconfig = kwargs.pop("ddconfig")
        ckpt_path = kwargs.pop("ckpt_path", None)
        ckpt_engine = kwargs.pop("ckpt_engine", None)
        super().__init__(
            encoder_config={
                "target": "sgm.modules.diffusionmodules.model.Encoder",
                "params": ddconfig,
            },
            decoder_config={
                "target": "sgm.modules.diffusionmodules.model.Decoder",
                "params": ddconfig,
            },
            **kwargs,
        )
        self.quant_conv = torch.nn.Conv2d(
            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
            (1 + ddconfig["double_z"]) * embed_dim,
            1,
        )
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

        self.apply_ckpt(default(ckpt_path, ckpt_engine))

    def get_autoencoder_params(self) -> list:
        params = super().get_autoencoder_params()
        return params

    def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        if self.max_batch_size is None:
            z = self.encoder(x)
            z = self.quant_conv(z)
        else:
            N = x.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            z = list()
            for i_batch in range(n_batches):
                z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
                z_batch = self.quant_conv(z_batch)
                z.append(z_batch)
            z = torch.cat(z, 0)

        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
        if self.max_batch_size is None:
            dec = self.post_quant_conv(z)
            dec = self.decoder(dec, **decoder_kwargs)
        else:
            N = z.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            dec = list()
            for i_batch in range(n_batches):
                dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
                dec_batch = self.decoder(dec_batch, **decoder_kwargs)
                dec.append(dec_batch)
            dec = torch.cat(dec, 0)

        return dec
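The quant_conv / post_quant_conv shapes above follow the usual AutoencoderKL convention: with double_z the encoder emits mean and log-variance stacked along the channel axis, so both sides of the 1x1 quantization conv are doubled. A small worked example with illustrative values:

ddconfig = {"double_z": True, "z_channels": 4}   # illustrative values, not from a shipped config
embed_dim = 4
in_channels = (1 + ddconfig["double_z"]) * ddconfig["z_channels"]  # 2 * 4 = 8 (mean + logvar)
out_channels = (1 + ddconfig["double_z"]) * embed_dim              # 2 * 4 = 8
print(in_channels, out_channels)                                   # quant_conv: 8 -> 8, kernel size 1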
class IdentityFirstStage(AbstractAutoencoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_input(self, x: Any) -> Any:
        return x

    def encode(self, x: Any, *args, **kwargs) -> Any:
        return x

    def decode(self, x: Any, *args, **kwargs) -> Any:
        return x
class VideoAutoencodingEngine(AutoencodingEngine):
    def __init__(
        self,
        ckpt_path: Union[None, str] = None,
        ignore_keys: Union[Tuple, list] = (),
        image_video_weights=[1, 1],
        only_train_decoder=False,
        context_parallel_size=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.context_parallel_size = context_parallel_size
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
        return self.log_images(batch, additional_log_kwargs, **kwargs)

    def get_input(self, batch: dict) -> torch.Tensor:
        if self.context_parallel_size > 0:
            if not is_context_parallel_initialized():
                initialize_context_parallel(self.context_parallel_size)

            batch = batch[self.input_key]

            global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
            torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())

            batch = _conv_split(batch, dim=2, kernel_size=1)
            return batch

        return batch[self.input_key]

    def apply_ckpt(self, ckpt: Union[None, str, dict]):
        if ckpt is None:
            return
        self.init_from_ckpt(ckpt)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    del sd[k]

        missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
        print("Missing keys: ", missing_keys)
        print("Unexpected keys: ", unexpected_keys)
        print(f"Restored from {path}")
class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
    def __init__(
        self,
        cp_size=0,
        *args,
        **kwargs,
    ):
        self.cp_size = cp_size
        return super().__init__(*args, **kwargs)

    def encode(
        self,
        x: torch.Tensor,
        return_reg_log: bool = False,
        unregularized: bool = False,
        input_cp: bool = False,
        output_cp: bool = False,
        use_cp: bool = True,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        if self.cp_size <= 1:
            use_cp = False
        if self.cp_size > 0 and use_cp and not input_cp:
            if not is_context_parallel_initialized():
                initialize_context_parallel(self.cp_size)

            global_src_rank = get_context_parallel_group_rank() * self.cp_size
            torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())

            x = _conv_split(x, dim=2, kernel_size=1)

        if return_reg_log:
            z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
        else:
            z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)

        if self.cp_size > 0 and use_cp and not output_cp:
            z = _conv_gather(z, dim=2, kernel_size=1)

        if return_reg_log:
            return z, reg_log
        return z

    def decode(
        self,
        z: torch.Tensor,
        input_cp: bool = False,
        output_cp: bool = False,
        use_cp: bool = True,
        **kwargs,
    ):
        if self.cp_size <= 1:
            use_cp = False
        if self.cp_size > 0 and use_cp and not input_cp:
            if not is_context_parallel_initialized():
                initialize_context_parallel(self.cp_size)

            global_src_rank = get_context_parallel_group_rank() * self.cp_size
            torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())

            z = _conv_split(z, dim=2, kernel_size=1)

        x = super().decode(z, use_cp=use_cp, **kwargs)

        if self.cp_size > 0 and use_cp and not output_cp:
            x = _conv_gather(x, dim=2, kernel_size=1)

        return x

    def forward(
        self,
        x: torch.Tensor,
        input_cp: bool = False,
        latent_cp: bool = False,
        output_cp: bool = False,
        **additional_decode_kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
        dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
        return z, dec, reg_log
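A single-process usage sketch (hypothetical: `video_vae` stands for a VideoAutoencoderInferenceWrapper built from the corresponding VAE config, which is not shown here). With cp_size <= 1 the context-parallel broadcast/split/gather branches above are skipped entirely:

import torch

# `video_vae` is assumed to be a VideoAutoencoderInferenceWrapper instance with cp_size=0.
video = torch.randn(1, 3, 9, 256, 256)                  # [batch, channels, frames, height, width]
latent, reg_log = video_vae.encode(video, return_reg_log=True, use_cp=False)
recon = video_vae.decode(latent, use_cp=False)
print(latent.shape, recon.shape)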
cogvideox-based/sat/sgm/modules/__init__.py
from .encoders.modules import GeneralConditioner

UNCONDITIONAL_CONFIG = {
    "target": "sgm.modules.GeneralConditioner",
    "params": {"emb_models": []},
}
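UNCONDITIONAL_CONFIG is typically used as the default conditioner config when none is supplied; instantiating it yields a GeneralConditioner with an empty embedder list. A minimal sketch (assumes the sgm package is on the import path):

from sgm.util import instantiate_from_config
from sgm.modules import UNCONDITIONAL_CONFIG

# Builds a GeneralConditioner with no embedding models ("emb_models": []).
conditioner = instantiate_from_config(UNCONDITIONAL_CONFIG)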
cogvideox-based/sat/sgm/modules/__pycache__/__init__.cpython-39.pyc
File added
cogvideox-based/sat/sgm/modules/__pycache__/attention.cpython-39.pyc
File added
cogvideox-based/sat/sgm/modules/__pycache__/cp_enc_dec.cpython-39.pyc
File added
cogvideox-based/sat/sgm/modules/__pycache__/ema.cpython-39.pyc
File added
cogvideox-based/sat/sgm/modules/__pycache__/fuse_sft_block.cpython-39.pyc
File added