ModelZoo / STAR / Commits / 1f5da520

Commit 1f5da520 authored Dec 05, 2025 by yangzhong
git init
Pipeline #3144 failed with stages in 0 seconds
Showing 6 changed files with 708 additions and 0 deletions
video_to_video/utils/__pycache__/seed.cpython-310.pyc   +0   -0
video_to_video/utils/config.py                          +169 -0
video_to_video/utils/logger.py                          +95  -0
video_to_video/utils/seed.py                            +14  -0
video_to_video/video_to_video_model.py                  +212 -0
video_to_video/video_to_video_model_local.py            +218 -0
video_to_video/utils/__pycache__/seed.cpython-310.pyc (new file, mode 100644)
File added
video_to_video/utils/config.py (new file, mode 100644)
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import os.path as osp
from datetime import datetime

import torch
from easydict import EasyDict

cfg = EasyDict(__name__='Config: VideoLDM Decoder')

# ---------------------------work dir--------------------------
cfg.work_dir = 'workspace/'

# ---------------------------Global Variable-----------------------------------
cfg.resolution = [448, 256]
cfg.max_frames = 32
# -----------------------------------------------------------------------------

# ---------------------------Dataset Parameter---------------------------------
cfg.mean = [0.5, 0.5, 0.5]
cfg.std = [0.5, 0.5, 0.5]
cfg.max_words = 1000

# Placeholder
cfg.vit_out_dim = 1024
cfg.vit_resolution = [224, 224]
cfg.depth_clamp = 10.0
cfg.misc_size = 384
cfg.depth_std = 20.0

cfg.frame_lens = 32
cfg.sample_fps = 8
cfg.batch_sizes = 1
# -----------------------------------------------------------------------------

# ---------------------------Mode Parameters-----------------------------------
# Diffusion
cfg.schedule = 'cosine'
cfg.num_timesteps = 1000
cfg.mean_type = 'v'
cfg.var_type = 'fixed_small'
cfg.loss_type = 'mse'
cfg.ddim_timesteps = 50
cfg.ddim_eta = 0.0
cfg.clamp = 1.0
cfg.share_noise = False
cfg.use_div_loss = False
cfg.noise_strength = 0.1

# classifier-free guidance
cfg.p_zero = 0.1
cfg.guide_scale = 3.0

# clip vision encoder
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]

# Model
cfg.scale_factor = 0.18215
cfg.use_fp16 = True
cfg.temporal_attention = True
cfg.decoder_bs = 8

cfg.UNet = {
    'type': 'Vid2VidSDUNet',
    'in_dim': 4,
    'dim': 320,
    'y_dim': cfg.vit_out_dim,
    'context_dim': 1024,
    'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
    'dim_mult': [1, 2, 4, 4],
    'num_heads': 8,
    'head_dim': 64,
    'num_res_blocks': 2,
    'attn_scales': [1 / 1, 1 / 2, 1 / 4],
    'dropout': 0.1,
    'temporal_attention': cfg.temporal_attention,
    'temporal_attn_times': 1,
    'use_checkpoint': False,
    'use_fps_condition': False,
    'use_sim_mask': False,
    'num_tokens': 4,
    'default_fps': 8,
    'input_dim': 1024
}

cfg.guidances = []

# autoencoder from stable diffusion
cfg.auto_encoder = {
    'type': 'AutoencoderKL',
    'ddconfig': {
        'double_z': True,
        'z_channels': 4,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0
    },
    'embed_dim': 4,
    'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
}

# clip embedder
cfg.embedder = {
    'type': 'FrozenOpenCLIPEmbedder',
    'layer': 'penultimate',
    'vit_resolution': [224, 224],
    'pretrained': 'open_clip_pytorch_model.bin'
}
# -----------------------------------------------------------------------------

# ---------------------------Training Settings---------------------------------
# training and optimizer
cfg.ema_decay = 0.9999
cfg.num_steps = 600000
cfg.lr = 5e-5
cfg.weight_decay = 0.0
cfg.betas = (0.9, 0.999)
cfg.eps = 1.0e-8
cfg.chunk_size = 16
cfg.alpha = 0.7
cfg.save_ckp_interval = 1000
# -----------------------------------------------------------------------------

# ----------------------------Pretrain Settings---------------------------------
# Default: load 2d pretrain
cfg.fix_weight = False
cfg.load_match = False
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
# -----------------------------------------------------------------------------

# -----------------------------Visual-------------------------------------------
# Visual videos
cfg.viz_interval = 1000
cfg.visual_train = {
    'type': 'VisualVideoTextDuringTrain',
}
cfg.visual_inference = {
    'type': 'VisualGeneratedVideos',
}
cfg.inference_list_path = ''

# logging
cfg.log_interval = 100

# Default log_dir
cfg.log_dir = 'workspace/output_data'
# -----------------------------------------------------------------------------

# ---------------------------Others--------------------------------------------
# seed
cfg.seed = 8888

cfg.negative_prompt = 'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \
    CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \
    signature, jpeg artifacts, deformed, lowres, over-smooth'

cfg.positive_prompt = 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \
    hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \
    skin pore detailing, hyper sharpness, perfect without deformations.'
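For context, a minimal sketch (not part of this commit) of how the EasyDict-based config is consumed at call sites: fields are read or overridden as plain attributes, which is how video_to_video_model.py later assigns cfg.model_path. The checkpoint path below is hypothetical.

# Sketch only: shows attribute-style access and per-run overrides on cfg.
from video_to_video.utils.config import cfg

# Attribute access and item access are interchangeable on an EasyDict.
assert cfg.resolution == cfg['resolution'] == [448, 256]

# Per-run overrides are plain assignments, as video_to_video_model.py does
# with cfg.model_path (the path here is a hypothetical example).
cfg.model_path = 'pretrained_weight/star_model.pt'

print(cfg.UNet['out_dim'])  # 4, because cfg.var_type is 'fixed_small', not 'learned*'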
video_to_video/utils/logger.py (new file, mode 100644)
# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib
import logging
from typing import Optional

from torch import distributed as dist

init_loggers = {}

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')


def get_logger(log_file: Optional[str] = None,
               log_level: int = logging.INFO,
               file_mode: str = 'w'):
    """ Get logging logger

    Args:
        log_file: Log filename, if specified, file handler will be added to
            logger
        log_level: Logging level.
        file_mode: Specifies the mode to open the file, if filename is
            specified (if filemode is unspecified, it defaults to 'w').
    """
    logger_name = __name__.split('.')[0]
    logger = logging.getLogger(logger_name)
    logger.propagate = False
    if logger_name in init_loggers:
        add_file_handler_if_needed(logger, log_file, file_mode, log_level)
        return logger

    # handle duplicate logs to the console
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
    # to the root logger. As logger.propagate is True by default, this root
    # level handler causes logging messages from rank>0 processes to
    # unexpectedly show up on the console, creating much unwanted clutter.
    # To fix this issue, we set the root logger's StreamHandler, if any, to log
    # at the ERROR level.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if importlib.util.find_spec('torch') is not None:
        is_worker0 = is_master()
    else:
        is_worker0 = True

    if is_worker0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if is_worker0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    init_loggers[logger_name] = True

    return logger


def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
    for handler in logger.handlers:
        if isinstance(handler, logging.FileHandler):
            return

    if importlib.util.find_spec('torch') is not None:
        is_worker0 = is_master()
    else:
        is_worker0 = True

    if is_worker0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, file_mode)
        file_handler.setFormatter(formatter)
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)


def is_master(group=None):
    return dist.get_rank(group) == 0 if is_dist() else True


def is_dist():
    return dist.is_available() and dist.is_initialized()
\ No newline at end of file
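A short usage sketch (not part of the commit, assuming the package is importable): the first get_logger() call configures the shared package-level logger for console output, and a later call that passes log_file attaches a FileHandler through add_file_handler_if_needed. On a non-master rank the logger is silenced to ERROR.

# Sketch only: exercises the logger module defined above.
from video_to_video.utils.logger import get_logger

logger = get_logger()                    # console handler only, INFO on rank 0
logger.info('hello from the console')

logger = get_logger(log_file='workspace/train.log')  # adds a FileHandler once
logger.info('this line also lands in workspace/train.log')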
video_to_video/utils/seed.py (new file, mode 100644)
# Copyright (c) Alibaba, Inc. and its affiliates.
import random

import numpy as np
import torch


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
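A one-line usage sketch (not part of the commit): an inference or training script would seed every RNG the pipeline touches before sampling, for example with the default seed from the config.

# Sketch only: seed torch, torch.cuda, numpy and random in one call.
from video_to_video.utils.config import cfg
from video_to_video.utils.seed import setup_seed

setup_seed(cfg.seed)  # 8888 by default; also makes cuDNN deterministic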
video_to_video/video_to_video_model.py (new file, mode 100644)
import os
os.environ['CURL_CA_BUNDLE'] = ''
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

import os.path as osp
import random
from typing import Any, Dict

import torch
import torch.cuda.amp as amp
import torch.nn.functional as F

from video_to_video.modules import *
from video_to_video.utils.config import cfg
from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion
from video_to_video.diffusion.schedules_sdedit import noise_schedule
from video_to_video.utils.logger import get_logger

from diffusers import AutoencoderKLTemporalDecoder

logger = get_logger()


class VideoToVideo_sr():

    def __init__(self, opt, device=torch.device(f'cuda:0')):
        self.opt = opt
        self.device = device  # torch.device(f'cuda:0')

        # text_encoder
        text_encoder = FrozenOpenCLIPEmbedder(device=self.device, pretrained="laion2b_s32b_b79k")
        text_encoder.model.to(self.device)
        self.text_encoder = text_encoder
        logger.info(f'Build encoder with FrozenOpenCLIPEmbedder')

        # U-Net with ControlNet
        generator = ControlledV2VUNet()
        generator = generator.to(self.device)
        generator.eval()

        cfg.model_path = opt.model_path
        load_dict = torch.load(cfg.model_path, map_location='cpu')
        if 'state_dict' in load_dict:
            load_dict = load_dict['state_dict']
        ret = generator.load_state_dict(load_dict, strict=False)

        self.generator = generator.half()
        logger.info('Load model path {}, with local status {}'.format(cfg.model_path, ret))

        # Noise scheduler
        sigmas = noise_schedule(
            schedule='logsnr_cosine_interp',
            n=1000,
            zero_terminal_snr=True,
            scale_min=2.0,
            scale_max=4.0)
        diffusion = GaussianDiffusion(sigmas=sigmas)
        self.diffusion = diffusion
        logger.info('Build diffusion with GaussianDiffusion')

        # Temporal VAE
        vae = AutoencoderKLTemporalDecoder.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid", subfolder="vae", variant="fp16")
        vae.eval()
        vae.requires_grad_(False)
        vae.to(self.device)
        self.vae = vae
        logger.info('Build Temporal VAE')

        torch.cuda.empty_cache()

        self.negative_prompt = cfg.negative_prompt
        self.positive_prompt = cfg.positive_prompt

        negative_y = text_encoder(self.negative_prompt).detach()
        self.negative_y = negative_y

    def test(self, input: Dict[str, Any], total_noise_levels=1000,
             steps=50, solver_mode='fast', guide_scale=7.5, max_chunk_len=32):
        video_data = input['video_data']
        y = input['y']
        (target_h, target_w) = input['target_res']

        video_data = F.interpolate(video_data, [target_h, target_w], mode='bilinear')

        logger.info(f'video_data shape: {video_data.shape}')
        frames_num, _, h, w = video_data.shape

        padding = pad_to_fit(h, w)
        video_data = F.pad(video_data, padding, 'constant', 1)

        video_data = video_data.unsqueeze(0)
        bs = 1
        video_data = video_data.to(self.device)

        video_data_feature = self.vae_encode(video_data)
        torch.cuda.empty_cache()

        y = self.text_encoder(y).detach()

        with amp.autocast(enabled=True):

            t = torch.LongTensor([total_noise_levels - 1]).to(self.device)
            noised_lr = self.diffusion.diffuse(video_data_feature, t)

            model_kwargs = [{'y': y}, {'y': self.negative_y}]
            model_kwargs.append({'hint': video_data_feature})

            torch.cuda.empty_cache()
            chunk_inds = make_chunks(frames_num, interp_f_num=0, max_chunk_len=max_chunk_len) if frames_num > max_chunk_len else None

            solver = 'dpmpp_2m_sde'  # 'heun' | 'dpmpp_2m_sde'
            gen_vid = self.diffusion.sample_sr(
                noise=noised_lr,
                model=self.generator,
                model_kwargs=model_kwargs,
                guide_scale=guide_scale,
                guide_rescale=0.2,
                solver=solver,
                solver_mode=solver_mode,
                return_intermediate=None,
                steps=steps,
                t_max=total_noise_levels - 1,
                t_min=0,
                discretization='trailing',
                chunk_inds=chunk_inds,
            )
            torch.cuda.empty_cache()
            logger.info(f'sampling, finished.')

            vid_tensor_gen = self.vae_decode_chunk(gen_vid, chunk_size=3)

            logger.info(f'temporal vae decoding, finished.')

        w1, w2, h1, h2 = padding
        vid_tensor_gen = vid_tensor_gen[:, :, h1:h + h1, w1:w + w1]

        gen_video = rearrange(vid_tensor_gen, '(b f) c h w -> b c f h w', b=bs)

        torch.cuda.empty_cache()

        return gen_video.type(torch.float32).cpu()

    def temporal_vae_decode(self, z, num_f):
        return self.vae.decode(z / self.vae.config.scaling_factor, num_frames=num_f).sample

    def vae_decode_chunk(self, z, chunk_size=3):
        z = rearrange(z, "b c f h w -> (b f) c h w")
        video = []
        for ind in range(0, z.shape[0], chunk_size):
            num_f = z[ind:ind + chunk_size].shape[0]
            video.append(self.temporal_vae_decode(z[ind:ind + chunk_size], num_f))
        video = torch.cat(video)
        return video

    def vae_encode(self, t, chunk_size=1):
        num_f = t.shape[1]
        t = rearrange(t, "b f c h w -> (b f) c h w")
        z_list = []
        for ind in range(0, t.shape[0], chunk_size):
            z_list.append(self.vae.encode(t[ind:ind + chunk_size]).latent_dist.sample())
        z = torch.cat(z_list, dim=0)
        z = rearrange(z, "(b f) c h w -> b c f h w", f=num_f)
        return z * self.vae.config.scaling_factor


def pad_to_fit(h, w):
    BEST_H, BEST_W = 720, 1280

    if h < BEST_H:
        h1, h2 = _create_pad(h, BEST_H)
    elif h == BEST_H:
        h1 = h2 = 0
    else:
        h1 = 0
        h2 = int((h + 48) // 64 * 64) + 64 - 48 - h

    if w < BEST_W:
        w1, w2 = _create_pad(w, BEST_W)
    elif w == BEST_W:
        w1 = w2 = 0
    else:
        w1 = 0
        w2 = int(w // 64 * 64) + 64 - w
    return (w1, w2, h1, h2)


def _create_pad(h, max_len):
    h1 = int((max_len - h) // 2)
    h2 = max_len - h1 - h
    return h1, h2


def make_chunks(f_num, interp_f_num, max_chunk_len, chunk_overlap_ratio=0.5):
    MAX_CHUNK_LEN = max_chunk_len
    MAX_O_LEN = MAX_CHUNK_LEN * chunk_overlap_ratio
    chunk_len = int((MAX_CHUNK_LEN - 1) // (1 + interp_f_num) * (interp_f_num + 1) + 1)
    o_len = int((MAX_O_LEN - 1) // (1 + interp_f_num) * (interp_f_num + 1) + 1)
    chunk_inds = sliding_windows_1d(f_num, chunk_len, o_len)
    return chunk_inds


def sliding_windows_1d(length, window_size, overlap_size):
    stride = window_size - overlap_size
    ind = 0
    coords = []
    while ind < length:
        if ind + window_size * 1.25 >= length:
            coords.append((ind, length))
            break
        else:
            coords.append((ind, ind + window_size))
            ind += stride
    return coords
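The padding and chunking helpers at the bottom of this file are pure functions, so their behavior can be sanity-checked without any weights. A small sketch follows (not part of the commit); it assumes the video_to_video package and its dependencies import cleanly, since importing this module also pulls in the modules/diffusion packages and builds the logger.

# Sketch only: exercises pad_to_fit and make_chunks with concrete numbers.
from video_to_video.video_to_video_model import pad_to_fit, make_chunks

# A 480x852 frame is padded symmetrically up to the 720x1280 "best" canvas;
# the result is ordered (w1, w2, h1, h2) for F.pad.
print(pad_to_fit(480, 852))  # (214, 214, 120, 120)

# 100 frames with a 32-frame window and 50% overlap (stride 16); the final
# window is stretched to absorb the tail instead of emitting a short chunk.
print(make_chunks(100, interp_f_num=0, max_chunk_len=32))
# [(0, 32), (16, 48), (32, 64), (48, 80), (64, 100)]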
video_to_video/video_to_video_model_local.py (new file, mode 100644)
import os
os.environ['CURL_CA_BUNDLE'] = ''
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

import os.path as osp
import random
from typing import Any, Dict

import torch
import torch.cuda.amp as amp
import torch.nn.functional as F

from video_to_video.modules import *
from video_to_video.utils.config import cfg
from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion
from video_to_video.diffusion.schedules_sdedit import noise_schedule
from video_to_video.utils.logger import get_logger

from diffusers import AutoencoderKLTemporalDecoder

logger = get_logger()


class VideoToVideo_sr():

    def __init__(self, opt, device=torch.device(f'cuda:0')):
        self.opt = opt
        self.device = device  # torch.device(f'cuda:0')

        # text_encoder
        local_clip_dir = "/yangzhong/STAR/pretrained_weight/laion/"
        assert os.path.exists(local_clip_dir), f"Local model directory does not exist: {local_clip_dir}"
        # text_encoder = FrozenOpenCLIPEmbedder(device=self.device, pretrained="laion2b_s32b_b79k")
        text_encoder = FrozenOpenCLIPEmbedder(
            device=self.device,
            pretrained="laion2b_s32b_b79k",
            cache_dir=local_clip_dir,
            force_download=False,
            local_files_only=True
        )
        text_encoder.model.to(self.device)
        self.text_encoder = text_encoder
        # logger.info(f'Build encoder with FrozenOpenCLIPEmbedder')
        logger.info(f'Build encoder with FrozenOpenCLIPEmbedder (local model path: {local_clip_dir})')

        # U-Net with ControlNet
        generator = ControlledV2VUNet()
        generator = generator.to(self.device)
        generator.eval()

        cfg.model_path = opt.model_path
        load_dict = torch.load(cfg.model_path, map_location='cpu')
        if 'state_dict' in load_dict:
            load_dict = load_dict['state_dict']
        ret = generator.load_state_dict(load_dict, strict=False)

        self.generator = generator.half()
        logger.info('Load model path {}, with local status {}'.format(cfg.model_path, ret))

        # Noise scheduler
        sigmas = noise_schedule(
            schedule='logsnr_cosine_interp',
            n=1000,
            zero_terminal_snr=True,
            scale_min=2.0,
            scale_max=4.0)
        diffusion = GaussianDiffusion(sigmas=sigmas)
        self.diffusion = diffusion
        logger.info('Build diffusion with GaussianDiffusion')

        # Temporal VAE
        vae = AutoencoderKLTemporalDecoder.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid", subfolder="vae", variant="fp16")
        vae.eval()
        vae.requires_grad_(False)
        vae.to(self.device)
        self.vae = vae
        logger.info('Build Temporal VAE')

        torch.cuda.empty_cache()

        self.negative_prompt = cfg.negative_prompt
        self.positive_prompt = cfg.positive_prompt

        negative_y = text_encoder(self.negative_prompt).detach()
        self.negative_y = negative_y

    def test(self, input: Dict[str, Any], total_noise_levels=1000,
             steps=50, solver_mode='fast', guide_scale=7.5, max_chunk_len=32):
        video_data = input['video_data']
        y = input['y']
        (target_h, target_w) = input['target_res']

        video_data = F.interpolate(video_data, [target_h, target_w], mode='bilinear')

        logger.info(f'video_data shape: {video_data.shape}')
        frames_num, _, h, w = video_data.shape

        padding = pad_to_fit(h, w)
        video_data = F.pad(video_data, padding, 'constant', 1)

        video_data = video_data.unsqueeze(0)
        bs = 1
        video_data = video_data.to(self.device)

        video_data_feature = self.vae_encode(video_data)
        torch.cuda.empty_cache()

        y = self.text_encoder(y).detach()

        with amp.autocast(enabled=True):

            t = torch.LongTensor([total_noise_levels - 1]).to(self.device)
            noised_lr = self.diffusion.diffuse(video_data_feature, t)

            model_kwargs = [{'y': y}, {'y': self.negative_y}]
            model_kwargs.append({'hint': video_data_feature})

            torch.cuda.empty_cache()
            chunk_inds = make_chunks(frames_num, interp_f_num=0, max_chunk_len=max_chunk_len) if frames_num > max_chunk_len else None

            solver = 'dpmpp_2m_sde'  # 'heun' | 'dpmpp_2m_sde'
            gen_vid = self.diffusion.sample_sr(
                noise=noised_lr,
                model=self.generator,
                model_kwargs=model_kwargs,
                guide_scale=guide_scale,
                guide_rescale=0.2,
                solver=solver,
                solver_mode=solver_mode,
                return_intermediate=None,
                steps=steps,
                t_max=total_noise_levels - 1,
                t_min=0,
                discretization='trailing',
                chunk_inds=chunk_inds,
            )
            torch.cuda.empty_cache()
            logger.info(f'sampling, finished.')

            vid_tensor_gen = self.vae_decode_chunk(gen_vid, chunk_size=3)

            logger.info(f'temporal vae decoding, finished.')

        w1, w2, h1, h2 = padding
        vid_tensor_gen = vid_tensor_gen[:, :, h1:h + h1, w1:w + w1]

        gen_video = rearrange(vid_tensor_gen, '(b f) c h w -> b c f h w', b=bs)

        torch.cuda.empty_cache()

        return gen_video.type(torch.float32).cpu()

    def temporal_vae_decode(self, z, num_f):
        return self.vae.decode(z / self.vae.config.scaling_factor, num_frames=num_f).sample

    def vae_decode_chunk(self, z, chunk_size=3):
        z = rearrange(z, "b c f h w -> (b f) c h w")
        video = []
        for ind in range(0, z.shape[0], chunk_size):
            num_f = z[ind:ind + chunk_size].shape[0]
            video.append(self.temporal_vae_decode(z[ind:ind + chunk_size], num_f))
        video = torch.cat(video)
        return video

    def vae_encode(self, t, chunk_size=1):
        num_f = t.shape[1]
        t = rearrange(t, "b f c h w -> (b f) c h w")
        z_list = []
        for ind in range(0, t.shape[0], chunk_size):
            z_list.append(self.vae.encode(t[ind:ind + chunk_size]).latent_dist.sample())
        z = torch.cat(z_list, dim=0)
        z = rearrange(z, "(b f) c h w -> b c f h w", f=num_f)
        return z * self.vae.config.scaling_factor


def pad_to_fit(h, w):
    BEST_H, BEST_W = 720, 1280

    if h < BEST_H:
        h1, h2 = _create_pad(h, BEST_H)
    elif h == BEST_H:
        h1 = h2 = 0
    else:
        h1 = 0
        h2 = int((h + 48) // 64 * 64) + 64 - 48 - h

    if w < BEST_W:
        w1, w2 = _create_pad(w, BEST_W)
    elif w == BEST_W:
        w1 = w2 = 0
    else:
        w1 = 0
        w2 = int(w // 64 * 64) + 64 - w
    return (w1, w2, h1, h2)


def _create_pad(h, max_len):
    h1 = int((max_len - h) // 2)
    h2 = max_len - h1 - h
    return h1, h2


def make_chunks(f_num, interp_f_num, max_chunk_len, chunk_overlap_ratio=0.5):
    MAX_CHUNK_LEN = max_chunk_len
    MAX_O_LEN = MAX_CHUNK_LEN * chunk_overlap_ratio
    chunk_len = int((MAX_CHUNK_LEN - 1) // (1 + interp_f_num) * (interp_f_num + 1) + 1)
    o_len = int((MAX_O_LEN - 1) // (1 + interp_f_num) * (interp_f_num + 1) + 1)
    chunk_inds = sliding_windows_1d(f_num, chunk_len, o_len)
    return chunk_inds


def sliding_windows_1d(length, window_size, overlap_size):
    stride = window_size - overlap_size
    ind = 0
    coords = []
    while ind < length:
        if ind + window_size * 1.25 >= length:
            coords.append((ind, length))
            break
        else:
            coords.append((ind, ind + window_size))
            ind += stride
    return coords
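Finally, an end-to-end usage sketch of this local-weights variant (not part of the commit, and it needs a GPU plus the downloaded weights). The checkpoint path, caption, and input value range are assumptions; the 'video_data'/'y'/'target_res' keys and the test() keyword arguments come from the code above, and EasyDict merely stands in for whatever option object the caller normally passes.

# Sketch only: builds the model and runs one super-resolution pass.
import torch
from easydict import EasyDict

from video_to_video.video_to_video_model_local import VideoToVideo_sr

opt = EasyDict(model_path='pretrained_weight/star_model.pt')  # hypothetical path
model = VideoToVideo_sr(opt)

video = torch.rand(32, 3, 270, 480) * 2 - 1   # 32 low-res frames, assumed [-1, 1] range
sample = {
    'video_data': video,                       # (f, c, h, w) tensor
    'y': 'a red fox walking through snow',     # caption fed to the text encoder
    'target_res': (1080, 1920),                # upscaled (h, w) before padding
}
out = model.test(sample, steps=50, solver_mode='fast', guide_scale=7.5, max_chunk_len=32)
print(out.shape)  # (1, 3, 32, 1080, 1920), float32 on CPU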