chenpangpang / LatentSync / Commits

Commit 5c023842, authored Jan 14, 2025 by chenpangpang
feat: Add LatentSync
Parent: 822b66ca
Pipeline #2211 canceled with stages
Showing 20 changed files with 5238 additions and 0 deletions
LatentSync/latentsync/data/unet_dataset.py  +164 -0
LatentSync/latentsync/models/attention.py  +492 -0
LatentSync/latentsync/models/motion_module.py  +332 -0
LatentSync/latentsync/models/resnet.py  +234 -0
LatentSync/latentsync/models/syncnet.py  +233 -0
LatentSync/latentsync/models/syncnet_wav2lip.py  +90 -0
LatentSync/latentsync/models/unet.py  +528 -0
LatentSync/latentsync/models/unet_blocks.py  +903 -0
LatentSync/latentsync/models/utils.py  +19 -0
LatentSync/latentsync/pipelines/lipsync_pipeline.py  +470 -0
LatentSync/latentsync/trepa/__init__.py  +64 -0
LatentSync/latentsync/trepa/third_party/VideoMAEv2/__init__.py  +0 -0
LatentSync/latentsync/trepa/third_party/VideoMAEv2/utils.py  +81 -0
LatentSync/latentsync/trepa/third_party/VideoMAEv2/videomaev2_finetune.py  +539 -0
LatentSync/latentsync/trepa/third_party/VideoMAEv2/videomaev2_pretrain.py  +469 -0
LatentSync/latentsync/trepa/third_party/__init__.py  +0 -0
LatentSync/latentsync/trepa/utils/__init__.py  +0 -0
LatentSync/latentsync/trepa/utils/data_utils.py  +321 -0
LatentSync/latentsync/trepa/utils/metric_utils.py  +161 -0
LatentSync/latentsync/utils/affine_transform.py  +138 -0
LatentSync/latentsync/data/unet_dataset.py  0 → 100644
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from torch.utils.data import Dataset
import torch
import random
import cv2

from ..utils.image_processor import ImageProcessor, load_fixed_mask
from ..utils.audio import melspectrogram
from decord import AudioReader, VideoReader, cpu


class UNetDataset(Dataset):
    def __init__(self, train_data_dir: str, config):
        if config.data.train_fileslist != "":
            with open(config.data.train_fileslist) as file:
                self.video_paths = [line.rstrip() for line in file]
        elif train_data_dir != "":
            self.video_paths = []
            for file in os.listdir(train_data_dir):
                if file.endswith(".mp4"):
                    self.video_paths.append(os.path.join(train_data_dir, file))
        else:
            raise ValueError("data_dir and fileslist cannot be both empty")

        self.resolution = config.data.resolution
        self.num_frames = config.data.num_frames

        if self.num_frames == 16:
            self.mel_window_length = 52
        elif self.num_frames == 5:
            self.mel_window_length = 16
        else:
            raise NotImplementedError("Only support 16 and 5 frames now")

        self.audio_sample_rate = config.data.audio_sample_rate
        self.video_fps = config.data.video_fps
        self.mask = config.data.mask
        self.mask_image = load_fixed_mask(self.resolution)
        self.load_audio_data = config.model.add_audio_layer and config.run.use_syncnet
        self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
        os.makedirs(self.audio_mel_cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.video_paths)

    def read_audio(self, video_path: str):
        ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
        original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
        return torch.from_numpy(original_mel)

    def crop_audio_window(self, original_mel, start_index):
        start_idx = int(80.0 * (start_index / float(self.video_fps)))
        end_idx = start_idx + self.mel_window_length
        return original_mel[:, start_idx:end_idx].unsqueeze(0)

    def get_frames(self, video_reader: VideoReader):
        total_num_frames = len(video_reader)

        start_idx = random.randint(self.num_frames // 2, total_num_frames - self.num_frames - self.num_frames // 2)
        frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)

        while True:
            wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
            if wrong_start_idx > start_idx - self.num_frames and wrong_start_idx < start_idx + self.num_frames:
                continue
            wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
            break

        frames = video_reader.get_batch(frames_index).asnumpy()
        wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()

        return frames, wrong_frames, start_idx

    def worker_init_fn(self, worker_id):
        # Initialize the face mesh object in each worker process,
        # because the face mesh object cannot be called in subprocesses
        self.worker_id = worker_id
        setattr(
            self,
            f"image_processor_{worker_id}",
            ImageProcessor(self.resolution, self.mask, mask_image=self.mask_image),
        )

    def __getitem__(self, idx):
        image_processor = getattr(self, f"image_processor_{self.worker_id}")
        while True:
            try:
                idx = random.randint(0, len(self) - 1)

                # Get video file path
                video_path = self.video_paths[idx]

                vr = VideoReader(video_path, ctx=cpu(self.worker_id))

                if len(vr) < 3 * self.num_frames:
                    continue

                continuous_frames, ref_frames, start_idx = self.get_frames(vr)

                if self.load_audio_data:
                    mel_cache_path = os.path.join(
                        self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
                    )

                    if os.path.isfile(mel_cache_path):
                        try:
                            original_mel = torch.load(mel_cache_path)
                        except Exception as e:
                            print(f"{type(e).__name__} - {e} - {mel_cache_path}")
                            os.remove(mel_cache_path)
                            original_mel = self.read_audio(video_path)
                            torch.save(original_mel, mel_cache_path)
                    else:
                        original_mel = self.read_audio(video_path)
                        torch.save(original_mel, mel_cache_path)

                    mel = self.crop_audio_window(original_mel, start_idx)

                    if mel.shape[-1] != self.mel_window_length:
                        continue
                else:
                    mel = []

                gt, masked_gt, mask = image_processor.prepare_masks_and_masked_images(
                    continuous_frames, affine_transform=False
                )

                if self.mask == "fix_mask":
                    ref, _, _ = image_processor.prepare_masks_and_masked_images(ref_frames, affine_transform=False)
                else:
                    ref = image_processor.process_images(ref_frames)
                vr.seek(0)  # avoid memory leak
                break

            except Exception as e:  # Handle the exception of face not detected
                print(f"{type(e).__name__} - {e} - {video_path}")
                if "vr" in locals():
                    vr.seek(0)  # avoid memory leak

        sample = dict(
            gt=gt,
            masked_gt=masked_gt,
            ref=ref,
            mel=mel,
            mask=mask,
            video_path=video_path,
            start_idx=start_idx,
        )

        return sample
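Not part of the commit: a minimal usage sketch for the dataset above. The config object and its path are assumptions (the class reads attribute-style fields such as config.data.resolution, which suggests an OmegaConf-style config); the key detail is that worker_init_fn has to be handed to the DataLoader, because __getitem__ looks up the per-worker ImageProcessor that worker_init_fn creates.

# Illustration only (not in this diff): wiring UNetDataset to a DataLoader.
from omegaconf import OmegaConf  # assumed config library; any attribute-style config would do
from torch.utils.data import DataLoader

config = OmegaConf.load("configs/unet.yaml")  # hypothetical path containing the fields read above
dataset = UNetDataset(train_data_dir="", config=config)

# __getitem__ fetches getattr(self, f"image_processor_{self.worker_id}"), so the per-worker
# initializer must be registered here; with num_workers=0 it would never be called.
dataloader = DataLoader(dataset, batch_size=4, num_workers=4, worker_init_fn=dataset.worker_init_fn)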
LatentSync/latentsync/models/attention.py  0 → 100644
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm

from einops import rearrange, repeat

from .utils import zero_module


@dataclass
class Transformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class Transformer3DModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_motion_module: bool = False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        add_audio_layer=False,
        audio_condition_method="cross_attn",
        custom_audio_layer: bool = False,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        if not custom_audio_layer:
            # Define transformers blocks
            self.transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        inner_dim,
                        num_attention_heads,
                        attention_head_dim,
                        dropout=dropout,
                        cross_attention_dim=cross_attention_dim,
                        activation_fn=activation_fn,
                        num_embeds_ada_norm=num_embeds_ada_norm,
                        attention_bias=attention_bias,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        use_motion_module=use_motion_module,
                        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                        unet_use_temporal_attention=unet_use_temporal_attention,
                        add_audio_layer=add_audio_layer,
                        custom_audio_layer=custom_audio_layer,
                        audio_condition_method=audio_condition_method,
                    )
                    for d in range(num_layers)
                ]
            )
        else:
            self.transformer_blocks = nn.ModuleList(
                [
                    AudioTransformerBlock(
                        inner_dim,
                        num_attention_heads,
                        attention_head_dim,
                        dropout=dropout,
                        cross_attention_dim=cross_attention_dim,
                        activation_fn=activation_fn,
                        num_embeds_ada_norm=num_embeds_ada_norm,
                        attention_bias=attention_bias,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        use_motion_module=use_motion_module,
                        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                        unet_use_temporal_attention=unet_use_temporal_attention,
                        add_audio_layer=add_audio_layer,
                    )
                    for d in range(num_layers)
                ]
            )

        # 4. Define output layers
        if use_linear_projection:
            self.proj_out = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

        if custom_audio_layer:
            self.proj_out = zero_module(self.proj_out)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        # Input
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")

        # No need to do this for audio input, because different audio samples are independent
        # encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length,
            )

        # Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_motion_module: bool = False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        add_audio_layer=False,
        custom_audio_layer=False,
        audio_condition_method="cross_attn",
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention
        self.use_motion_module = use_motion_module
        self.add_audio_layer = add_audio_layer

        # SC-Attn
        assert unet_use_cross_frame_attention is not None
        if unet_use_cross_frame_attention:
            raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
        else:
            self.attn1 = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        # Cross-Attn
        if add_audio_layer and audio_condition_method == "cross_attn" and not custom_audio_layer:
            self.audio_cross_attn = AudioCrossAttn(
                dim=dim,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                attention_bias=attention_bias,
                upcast_attention=upcast_attention,
                num_embeds_ada_norm=num_embeds_ada_norm,
                use_ada_layer_norm=self.use_ada_layer_norm,
                zero_proj_out=False,
            )
        else:
            self.audio_cross_attn = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)

        # Temp-Attn
        assert unet_use_temporal_attention is not None
        if unet_use_temporal_attention:
            self.attn_temp = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        if self.audio_cross_attn is not None:
            self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
                use_memory_efficient_attention_xformers
            )
        # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        # if self.only_cross_attention:
        #     hidden_states = (
        #         self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
        #     )
        # else:
        #     hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states

        # pdb.set_trace()
        if self.unet_use_cross_frame_attention:
            hidden_states = (
                self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
            )
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

        if self.audio_cross_attn is not None and encoder_hidden_states is not None:
            hidden_states = self.audio_cross_attn(
                hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # Temporal-Attention
        if self.unet_use_temporal_attention:
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states


class AudioTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_motion_module: bool = False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        add_audio_layer=False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention
        self.use_motion_module = use_motion_module
        self.add_audio_layer = add_audio_layer

        # SC-Attn
        assert unet_use_cross_frame_attention is not None
        if unet_use_cross_frame_attention:
            raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
        else:
            self.attn1 = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        self.audio_cross_attn = AudioCrossAttn(
            dim=dim,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            dropout=dropout,
            attention_bias=attention_bias,
            upcast_attention=upcast_attention,
            num_embeds_ada_norm=num_embeds_ada_norm,
            use_ada_layer_norm=self.use_ada_layer_norm,
            zero_proj_out=False,
        )

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        if self.audio_cross_attn is not None:
            self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
                use_memory_efficient_attention_xformers
            )
        # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        # pdb.set_trace()
        if self.unet_use_cross_frame_attention:
            hidden_states = (
                self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
            )
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

        if self.audio_cross_attn is not None and encoder_hidden_states is not None:
            hidden_states = self.audio_cross_attn(
                hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        return hidden_states


class AudioCrossAttn(nn.Module):
    def __init__(
        self,
        dim,
        cross_attention_dim,
        num_attention_heads,
        attention_head_dim,
        dropout,
        attention_bias,
        upcast_attention,
        num_embeds_ada_norm,
        use_ada_layer_norm,
        zero_proj_out=False,
    ):
        super().__init__()

        self.norm = AdaLayerNorm(dim, num_embeds_ada_norm) if use_ada_layer_norm else nn.LayerNorm(dim)
        self.attn = CrossAttention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
        )

        if zero_proj_out:
            self.proj_out = zero_module(nn.Linear(dim, dim))

        self.zero_proj_out = zero_proj_out
        self.use_ada_layer_norm = use_ada_layer_norm

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
        previous_hidden_states = hidden_states
        hidden_states = self.norm(hidden_states, timestep) if self.use_ada_layer_norm else self.norm(hidden_states)

        if encoder_hidden_states.dim() == 4:
            encoder_hidden_states = rearrange(encoder_hidden_states, "b f n d -> (b f) n d")
        hidden_states = self.attn(hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask)

        if self.zero_proj_out:
            hidden_states = self.proj_out(hidden_states)
        return hidden_states + previous_hidden_states
LatentSync/latentsync/models/motion_module.py  0 → 100644
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
# Actually we don't use the motion module in the final version of LatentSync
# When we started the project, we used the AnimateDiff codebase and tried the motion module
# But the results were poor, and we decided to leave the code here for possible future usage
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward

from einops import rearrange, repeat
import math

from .utils import zero_module


@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
    if motion_module_type == "Vanilla":
        return VanillaTemporalModule(
            in_channels=in_channels,
            **motion_module_kwargs,
        )
    else:
        raise ValueError


class VanillaTemporalModule(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        temporal_attention_dim_div=1,
        zero_initialize=True,
    ):
        super().__init__()

        self.temporal_transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
        )

        if zero_initialize:
            self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)

    def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
        hidden_states = input_tensor
        hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)

        output = hidden_states
        return output


class TemporalTransformer3DModel(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,
        num_layers,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
    ):
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                TemporalTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    attention_block_types=attention_block_types,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
                for d in range(num_layers)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
        hidden_states = self.proj_in(hidden_states)

        # Transformer Blocks
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)

        # output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual
        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)

        return output


class TemporalTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
    ):
        super().__init__()

        attention_blocks = []
        norms = []

        for block_name in attention_block_types:
            attention_blocks.append(
                VersatileAttention(
                    attention_mode=block_name.split("_")[0],
                    cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
                    query_dim=dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    dropout=dropout,
                    bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
            )
            norms.append(nn.LayerNorm(dim))

        self.attention_blocks = nn.ModuleList(attention_blocks)
        self.norms = nn.ModuleList(norms)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        for attention_block, norm in zip(self.attention_blocks, self.norms):
            norm_hidden_states = norm(hidden_states)
            hidden_states = (
                attention_block(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
                    video_length=video_length,
                )
                + hidden_states
            )

        hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states

        output = hidden_states
        return output


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.0, max_len=24):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class VersatileAttention(CrossAttention):
    def __init__(
        self,
        attention_mode=None,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert attention_mode == "Temporal"

        self.attention_mode = attention_mode
        self.is_cross_attention = kwargs["cross_attention_dim"] is not None

        self.pos_encoder = (
            PositionalEncoding(kwargs["query_dim"], dropout=0.0, max_len=temporal_position_encoding_max_len)
            if (temporal_position_encoding and attention_mode == "Temporal")
            else None
        )

    def extra_repr(self):
        return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        batch_size, sequence_length, _ = hidden_states.shape

        if self.attention_mode == "Temporal":
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)

            if self.pos_encoder is not None:
                hidden_states = self.pos_encoder(hidden_states)

            encoder_hidden_states = (
                repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
                if encoder_hidden_states is not None
                else encoder_hidden_states
            )
        else:
            raise NotImplementedError

        # encoder_hidden_states = encoder_hidden_states

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)
        dim = query.shape[-1]
        query = self.reshape_heads_to_batch_dim(query)

        if self.added_kv_proj_dim is not None:
            raise NotImplementedError

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

        # attention, what we cannot get enough of
        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(query, key, value, attention_mask)
            else:
                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)

        if self.attention_mode == "Temporal":
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
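Not part of the commit: a small shape check for the PositionalEncoding above, with made-up sizes. VersatileAttention applies it along the frame axis after rearranging tokens to ((b d), f, c).

# Illustration only (not in this diff).
import torch

pos_enc = PositionalEncoding(d_model=320, max_len=24)
tokens = torch.randn(4, 16, 320)   # ((b*d), f, c): 16 frames of 320-dim tokens per spatial location
out = pos_enc(tokens)              # sinusoidal offsets added along the frame axis, shape unchanged
print(out.shape)                   # torch.Size([4, 16, 320])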
LatentSync/latentsync/models/resnet.py  0 → 100644
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange


class InflatedConv3d(nn.Conv2d):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


class InflatedGroupNorm(nn.GroupNorm):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


class Upsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            raise NotImplementedError
        elif use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # if self.use_conv:
        #     if self.name == "conv":
        #         hidden_states = self.conv(hidden_states)
        #     else:
        #         hidden_states = self.Conv2d_0(hidden_states)
        hidden_states = self.conv(hidden_states)

        return hidden_states


class Downsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            raise NotImplementedError

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            raise NotImplementedError

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states


class ResnetBlock3D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
        use_inflated_groupnorm=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        assert use_inflated_groupnorm != None

        if use_inflated_groupnorm:
            self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            time_emb_proj_out_channels = out_channels
            # if self.time_embedding_norm == "default":
            #     time_emb_proj_out_channels = out_channels
            # elif self.time_embedding_norm == "scale_shift":
            #     time_emb_proj_out_channels = out_channels * 2
            # else:
            #     raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "scale_shift":
            self.double_len_linear = torch.nn.Linear(time_emb_proj_out_channels, 2 * time_emb_proj_out_channels)
        else:
            self.double_len_linear = None

        if use_inflated_groupnorm:
            self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            if temb.dim() == 2:
                # input (1, 1280)
                temb = self.time_emb_proj(self.nonlinearity(temb))
                temb = temb[:, :, None, None, None]  # unsqueeze
            else:
                # input (1, 1280, 16)
                temb = temb.permute(0, 2, 1)
                temb = self.time_emb_proj(self.nonlinearity(temb))
                if self.double_len_linear is not None:
                    temb = self.double_len_linear(self.nonlinearity(temb))
                temb = temb.permute(0, 2, 1)
                temb = temb[:, :, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
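Not part of the commit: a quick check of the inflation trick used throughout this file, with made-up tensor sizes. InflatedConv3d keeps ordinary nn.Conv2d weights but accepts (b, c, f, h, w) video tensors and convolves each frame independently.

# Illustration only (not in this diff).
import torch

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)
video = torch.randn(2, 4, 16, 32, 32)   # (batch, channels, frames, height, width)
out = conv(video)                       # 2D convolution applied per frame, then folded back to 5D
print(out.shape)                        # torch.Size([2, 8, 16, 32, 32])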
LatentSync/latentsync/models/syncnet.py  0 → 100644
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from einops import rearrange
from torch.nn import functional as F
from ..utils.util import cosine_loss

import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.attention import CrossAttention, FeedForward
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange


class SyncNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.audio_encoder = DownEncoder2D(
            in_channels=config["audio_encoder"]["in_channels"],
            block_out_channels=config["audio_encoder"]["block_out_channels"],
            downsample_factors=config["audio_encoder"]["downsample_factors"],
            dropout=config["audio_encoder"]["dropout"],
            attn_blocks=config["audio_encoder"]["attn_blocks"],
        )

        self.visual_encoder = DownEncoder2D(
            in_channels=config["visual_encoder"]["in_channels"],
            block_out_channels=config["visual_encoder"]["block_out_channels"],
            downsample_factors=config["visual_encoder"]["downsample_factors"],
            dropout=config["visual_encoder"]["dropout"],
            attn_blocks=config["visual_encoder"]["attn_blocks"],
        )

        self.eval()

    def forward(self, image_sequences, audio_sequences):
        vision_embeds = self.visual_encoder(image_sequences)  # (b, c, 1, 1)
        audio_embeds = self.audio_encoder(audio_sequences)  # (b, c, 1, 1)

        vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1)  # (b, c)
        audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1)  # (b, c)

        # Make them unit vectors
        vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
        audio_embeds = F.normalize(audio_embeds, p=2, dim=1)

        return vision_embeds, audio_embeds


class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        eps: float = 1e-6,
        act_fn: str = "silu",
        downsample_factor=2,
    ):
        super().__init__()

        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "silu":
            self.act_fn = nn.SiLU()

        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

        if isinstance(downsample_factor, list):
            downsample_factor = tuple(downsample_factor)

        if downsample_factor == 1:
            self.downsample_conv = None
        else:
            self.downsample_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0)
            self.pad = (0, 1, 0, 1)
            if isinstance(downsample_factor, tuple):
                if downsample_factor[0] == 1:
                    self.pad = (0, 1, 1, 1)  # The padding order is from back to front
                elif downsample_factor[1] == 1:
                    self.pad = (1, 1, 0, 1)

    def forward(self, input_tensor):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        hidden_states = self.conv1(hidden_states)
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        hidden_states += input_tensor

        if self.downsample_conv is not None:
            hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
            hidden_states = self.downsample_conv(hidden_states)

        return hidden_states


class AttentionBlock2D(nn.Module):
    def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
        super().__init__()
        if not is_xformers_available():
            raise ModuleNotFoundError("You have to install xformers to enable memory efficient attention", name="xformers")

        # inner_dim = dim_head * heads
        self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
        self.norm2 = nn.LayerNorm(query_dim)
        self.norm3 = nn.LayerNorm(query_dim)

        self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")

        self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
        self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)

        self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
        self.attn._use_memory_efficient_attention_xformers = True

    def forward(self, hidden_states):
        assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
        batch, channel, height, width = hidden_states.shape

        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.conv_in(hidden_states)
        hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")

        norm_hidden_states = self.norm2(hidden_states)
        hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
        hidden_states = self.conv_out(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states


class DownEncoder2D(nn.Module):
    def __init__(
        self,
        in_channels=4 * 16,
        block_out_channels=[64, 128, 256, 256],
        downsample_factors=[2, 2, 2, 2],
        layers_per_block=2,
        norm_num_groups=32,
        attn_blocks=[1, 1, 1, 1],
        dropout: float = 0.0,
        act_fn="silu",
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        # in
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        # down
        self.down_blocks = nn.ModuleList([])

        output_channels = block_out_channels[0]
        for i, block_out_channel in enumerate(block_out_channels):
            input_channels = output_channels
            output_channels = block_out_channel
            # is_final_block = i == len(block_out_channels) - 1

            down_block = ResnetBlock2D(
                in_channels=input_channels,
                out_channels=output_channels,
                downsample_factor=downsample_factors[i],
                norm_num_groups=norm_num_groups,
                dropout=dropout,
                act_fn=act_fn,
            )

            self.down_blocks.append(down_block)

            if attn_blocks[i] == 1:
                attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
                self.down_blocks.append(attention_block)

        # out
        self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.act_fn_out = nn.ReLU()

    def forward(self, hidden_states):
        hidden_states = self.conv_in(hidden_states)

        # down
        for down_block in self.down_blocks:
            hidden_states = down_block(hidden_states)

        # post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.act_fn_out(hidden_states)

        return hidden_states
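Not part of the commit: an illustration of how the two embeddings returned above compare. Both are L2-normalized, so an audio-visual sync score is just their dot product (cosine similarity); the actual training loss (cosine_loss from ..utils.util) is not included in this diff, and the batch and channel sizes below are made up.

# Illustration only (not in this diff).
import torch
import torch.nn.functional as F

# Stand-ins for the (b, c) unit-norm embeddings returned by SyncNet.forward.
vision_embeds = F.normalize(torch.randn(8, 1024), p=2, dim=1)
audio_embeds = F.normalize(torch.randn(8, 1024), p=2, dim=1)

# With unit vectors, cosine similarity reduces to an elementwise product and sum.
sync_score = (vision_embeds * audio_embeds).sum(dim=1)   # (b,), values in [-1, 1]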
LatentSync/latentsync/models/syncnet_wav2lip.py  0 → 100644
# Adapted from https://github.com/primepake/wav2lip_288x288/blob/master/models/syncnetv2.py
# The code here is for ablation study.
from
torch
import
nn
from
torch.nn
import
functional
as
F
class
SyncNetWav2Lip
(
nn
.
Module
):
def
__init__
(
self
,
act_fn
=
"leaky"
):
super
().
__init__
()
# input image sequences: (15, 128, 256)
self
.
visual_encoder
=
nn
.
Sequential
(
Conv2d
(
15
,
32
,
kernel_size
=
(
7
,
7
),
stride
=
1
,
padding
=
3
,
act_fn
=
act_fn
),
# (128, 256)
Conv2d
(
32
,
64
,
kernel_size
=
5
,
stride
=
(
1
,
2
),
padding
=
1
,
act_fn
=
act_fn
),
# (126, 127)
Conv2d
(
64
,
64
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
residual
=
True
,
act_fn
=
act_fn
),
Conv2d
(
64
,
64
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
residual
=
True
,
act_fn
=
act_fn
),
Conv2d
(
64
,
128
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
act_fn
=
act_fn
),
# (63, 64)
Conv2d
(
128
,
128
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
residual
=
True
,
act_fn
=
act_fn
),
Conv2d
(
128
,
128
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
residual
=
True
,
act_fn
=
act_fn
),
Conv2d
(
128
,
128
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
residual
=
True
,
act_fn
=
act_fn
),
Conv2d
(
128
,
256
,
kernel_size
=
3
,
stride
=
3
,
padding
=
1
,
act_fn
=
act_fn
),
# (21, 22)
Conv2d
(
256
,
256
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
residual
=
True
,
act_fn
=
act_fn
),
Conv2d
(
256
,
256
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
residual
=
            True, act_fn=act_fn),
            Conv2d(256, 512, kernel_size=3, stride=2, padding=1, act_fn=act_fn),  # (11, 11)
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, act_fn=act_fn),  # (6, 6)
            Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1, act_fn="relu"),  # (3, 3)
            Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"),  # (1, 1)
            Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
        )

        # input audio sequences: (1, 80, 16)
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1, act_fn=act_fn),  # (27, 16)
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(64, 128, kernel_size=3, stride=3, padding=1, act_fn=act_fn),  # (9, 6)
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1, act_fn=act_fn),  # (3, 3)
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(256, 512, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(512, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"),  # (1, 1)
            Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
        )

    def forward(self, image_sequences, audio_sequences):
        vision_embeds = self.visual_encoder(image_sequences)  # (b, c, 1, 1)
        audio_embeds = self.audio_encoder(audio_sequences)  # (b, c, 1, 1)

        vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1)  # (b, c)
        audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1)  # (b, c)

        # Make them unit vectors
        vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
        audio_embeds = F.normalize(audio_embeds, p=2, dim=1)

        return vision_embeds, audio_embeds


class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act_fn="relu", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "tanh":
            self.act_fn = nn.Tanh()
        elif act_fn == "silu":
            self.act_fn = nn.SiLU()
        elif act_fn == "leaky":
            self.act_fn = nn.LeakyReLU(0.2, inplace=True)
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act_fn(out)
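The two encoders above map a stacked-frame face crop and a mel-spectrogram window to embeddings that forward L2-normalizes, so audio-visual synchrony can be scored directly with a cosine similarity. Below is a minimal usage sketch: the class name SyncNetWav2Lip and its act_fn constructor argument are assumed from the surrounding code (the class definition line falls outside this excerpt), and the image tensor shape is illustrative; only the audio shape follows a comment shown above.

import torch

# Hypothetical usage sketch; class name, constructor signature, and image resolution are assumptions.
model = SyncNetWav2Lip(act_fn="leaky")
image_sequences = torch.randn(2, 15, 128, 256)  # several RGB frames stacked channel-wise (illustrative shape)
audio_sequences = torch.randn(2, 1, 80, 16)     # one mel window, per the (1, 80, 16) comment above
vision_embeds, audio_embeds = model(image_sequences, audio_sequences)
# Both embeddings are unit vectors, so the row-wise dot product is a cosine similarity in [-1, 1].
sync_score = (vision_embeds * audio_embeds).sum(dim=1)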
LatentSync/latentsync/models/unet.py
0 → 100644
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet.py

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import copy

import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers import UNet2DConditionModel
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

from .unet_blocks import (
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    UpBlock3D,
    get_down_block,
    get_up_block,
)
from .resnet import InflatedConv3d, InflatedGroupNorm
from ..utils.util import zero_rank_log
from einops import rearrange
from .utils import zero_module

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet3DConditionOutput(BaseOutput):
    sample: torch.FloatTensor


class UNet3DConditionModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        mid_block_type: str = "UNetMidBlock3DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        use_inflated_groupnorm=False,
        # Additional
        use_motion_module=False,
        motion_module_resolutions=(1, 2, 4, 8),
        motion_module_mid_block=False,
        motion_module_decoder_only=False,
        motion_module_type=None,
        motion_module_kwargs={},
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        add_audio_layer=False,
        audio_condition_method: str = "cross_attn",
        custom_audio_layer=False,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4
        self.use_motion_module = use_motion_module
        self.add_audio_layer = add_audio_layer

        self.conv_in = zero_module(InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)))

        # time
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            res = 2**i
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel,
                temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps,
                resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,
                use_motion_module=use_motion_module
                and (res in motion_module_resolutions)
                and (not motion_module_decoder_only),
                motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs,
                add_audio_layer=add_audio_layer, audio_condition_method=audio_condition_method,
                custom_audio_layer=custom_audio_layer,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock3DCrossAttn":
            self.mid_block = UNetMidBlock3DCrossAttn(
                in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps,
                resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,
                use_motion_module=use_motion_module and motion_module_mid_block,
                motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs,
                add_audio_layer=add_audio_layer, audio_condition_method=audio_condition_method,
                custom_audio_layer=custom_audio_layer,
            )
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the videos
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            res = 2 ** (3 - i)
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1, in_channels=input_channel, out_channels=output_channel,
                prev_output_channel=prev_output_channel, temb_channels=time_embed_dim, add_upsample=add_upsample,
                resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim, attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,
                use_motion_module=use_motion_module and (res in motion_module_resolutions),
                motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs,
                add_audio_layer=add_audio_layer, audio_condition_method=audio_condition_method,
                custom_audio_layer=custom_audio_layer,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if use_inflated_groupnorm:
            self.conv_norm_out = InflatedGroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = zero_module(InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1))

    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_slicable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_slicable_dims(module)

        num_slicable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_slicable_layers * [1]

        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        # support controlnet
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # pre-process
        sample = self.conv_in(sample)

        # down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
                )

            down_block_res_samples += res_samples

        # support controlnet
        down_block_res_samples = list(down_block_res_samples)
        if down_block_additional_residuals is not None:
            for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
                if down_block_additional_residual.dim() == 4:  # broadcast
                    down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
                down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual

        # mid
        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask)

        # support controlnet
        if mid_block_additional_residual is not None:
            if mid_block_additional_residual.dim() == 4:  # broadcast
                mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
            sample = sample + mid_block_additional_residual

        # up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    encoder_hidden_states=encoder_hidden_states,
                )

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)

    def load_state_dict(self, state_dict, strict=True):
        # If the loaded checkpoint's in_channels or out_channels are different from config
        temp_state_dict = copy.deepcopy(state_dict)
        if temp_state_dict["conv_in.weight"].shape[1] != self.config.in_channels:
            del temp_state_dict["conv_in.weight"]
            del temp_state_dict["conv_in.bias"]
        if temp_state_dict["conv_out.weight"].shape[0] != self.config.out_channels:
            del temp_state_dict["conv_out.weight"]
            del temp_state_dict["conv_out.bias"]

        # If the loaded checkpoint's cross_attention_dim is different from config
        keys_to_remove = []
        for key in temp_state_dict:
            if "audio_cross_attn.attn.to_k." in key or "audio_cross_attn.attn.to_v." in key:
                if temp_state_dict[key].shape[1] != self.config.cross_attention_dim:
                    keys_to_remove.append(key)

        for key in keys_to_remove:
            del temp_state_dict[key]

        return super().load_state_dict(state_dict=temp_state_dict, strict=strict)

    @classmethod
    def from_pretrained(cls, model_config: dict, ckpt_path: str, device="cpu"):
        unet = cls.from_config(model_config).to(device)

        if ckpt_path != "":
            zero_rank_log(logger, f"Load from checkpoint: {ckpt_path}")
            ckpt = torch.load(ckpt_path, map_location=device)
            if "global_step" in ckpt:
                zero_rank_log(logger, f"resume from global_step: {ckpt['global_step']}")
                resume_global_step = ckpt["global_step"]
            else:
                resume_global_step = 0
            state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
            unet.load_state_dict(state_dict, strict=False)
        else:
            resume_global_step = 0

        return unet, resume_global_step
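For reference, the custom from_pretrained above builds the UNet from a plain config dict and returns both the model and the global step recovered from the checkpoint. A minimal loading sketch, assuming a config whose values simply mirror the defaults of __init__ shown earlier (a real LatentSync config will differ) and a placeholder checkpoint path that is not part of this commit:

import torch

# Hypothetical loading sketch; "checkpoints/unet.pt" is a placeholder path, not a shipped file.
model_config = {
    "in_channels": 4,
    "out_channels": 4,
    "block_out_channels": (320, 640, 1280, 1280),
    "cross_attention_dim": 1280,
    "add_audio_layer": True,
}
unet, resume_global_step = UNet3DConditionModel.from_pretrained(model_config, "checkpoints/unet.pt", device="cpu")
unet.eval()
# Passing ckpt_path="" skips weight loading entirely and returns resume_global_step == 0.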
LatentSync/latentsync/models/unet_blocks.py
0 → 100644
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py

import torch
from torch import nn

from .attention import Transformer3DModel
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
from .motion_module import get_motion_module


def get_down_block(
    down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample,
    resnet_eps, resnet_act_fn, attn_num_head_channels, resnet_groups=None, cross_attention_dim=None,
    downsample_padding=None, dual_cross_attention=False, use_linear_projection=False,
    only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False,
    use_motion_module=None, motion_module_type=None, motion_module_kwargs=None,
    add_audio_layer=False, audio_condition_method="cross_attn", custom_audio_layer=False,
):
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            num_layers=num_layers, in_channels=in_channels, out_channels=out_channels,
            temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift, use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module, motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            num_layers=num_layers, in_channels=in_channels, out_channels=out_channels,
            temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention, upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module,
            motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs,
            add_audio_layer=add_audio_layer, audio_condition_method=audio_condition_method,
            custom_audio_layer=custom_audio_layer,
        )
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
    up_block_type, num_layers, in_channels, out_channels, prev_output_channel, temb_channels, add_upsample,
    resnet_eps, resnet_act_fn, attn_num_head_channels, resnet_groups=None, cross_attention_dim=None,
    dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False,
    resnet_time_scale_shift="default", unet_use_cross_frame_attention=False, unet_use_temporal_attention=False,
    use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None,
    add_audio_layer=False, audio_condition_method="cross_attn", custom_audio_layer=False,
):
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock3D":
        return UpBlock3D(
            num_layers=num_layers, in_channels=in_channels, out_channels=out_channels,
            prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample,
            resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift, use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module, motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            num_layers=num_layers, in_channels=in_channels, out_channels=out_channels,
            prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample,
            resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention, upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module,
            motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs,
            add_audio_layer=add_audio_layer, audio_condition_method=audio_condition_method,
            custom_audio_layer=custom_audio_layer,
        )
    raise ValueError(f"{up_block_type} does not exist.")


class UNetMidBlock3DCrossAttn(nn.Module):
    def __init__(
        self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1,
        resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish",
        resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, output_scale_factor=1.0,
        cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, upcast_attention=False,
        unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False,
        use_motion_module=None, motion_module_type=None, motion_module_kwargs=None,
        add_audio_layer=False, audio_condition_method="cross_attn", custom_audio_layer: bool = False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            ResnetBlock3D(
                in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels,
                eps=resnet_eps, groups=resnet_groups, dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
                use_inflated_groupnorm=use_inflated_groupnorm,
            )
        ]
        attentions = []
        audio_attentions = []
        motion_modules = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels, in_channels // attn_num_head_channels, in_channels=in_channels,
                    num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection, upcast_attention=upcast_attention,
                    use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer, audio_condition_method=audio_condition_method,
                )
            )
            audio_attentions.append(
                Transformer3DModel(
                    attn_num_head_channels, in_channels // attn_num_head_channels, in_channels=in_channels,
                    num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection, upcast_attention=upcast_attention,
                    use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer, audio_condition_method=audio_condition_method,
                    custom_audio_layer=True,
                )
                if custom_audio_layer
                else None
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=in_channels, motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels,
                    eps=resnet_eps, groups=resnet_groups, dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.audio_attentions = nn.ModuleList(audio_attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, audio_attn, resnet, motion_module in zip(
            self.attentions, self.audio_attentions, self.resnets[1:], self.motion_modules
        ):
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
            hidden_states = (
                audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                if audio_attn is not None
                else hidden_states
            )
            hidden_states = (
                motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                if motion_module is not None
                else hidden_states
            )
            hidden_states = resnet(hidden_states, temb)

        return hidden_states


class CrossAttnDownBlock3D(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1,
        resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish",
        resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280,
        output_scale_factor=1.0, downsample_padding=1, add_downsample=True, dual_cross_attention=False,
        use_linear_projection=False, only_cross_attention=False, upcast_attention=False,
        unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False,
        use_motion_module=None, motion_module_type=None, motion_module_kwargs=None,
        add_audio_layer=False, audio_condition_method="cross_attn", custom_audio_layer: bool = False,
    ):
        super().__init__()
        resnets = []
        attentions = []
        audio_attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels,
                    eps=resnet_eps, groups=resnet_groups, dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels,
                    num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention, use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer, audio_condition_method=audio_condition_method,
                )
            )
            audio_attentions.append(
                Transformer3DModel(
                    attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels,
                    num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention, use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer, audio_condition_method=audio_condition_method,
                    custom_audio_layer=True,
                )
                if custom_audio_layer
                else None
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels, motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.audio_attentions = nn.ModuleList(audio_attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [Downsample3D(out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op")]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        output_states = ()

        for resnet, attn, audio_attn, motion_module in zip(
            self.resnets, self.attentions, self.audio_attentions, self.motion_modules
        ):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                hidden_states = (
                    audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                    if audio_attn is not None
                    else hidden_states
                )

                # add motion module
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states


class DownBlock3D(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1,
        resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish",
        resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_downsample=True,
        downsample_padding=1, use_inflated_groupnorm=False,
        use_motion_module=None, motion_module_type=None, motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels,
                    eps=resnet_eps, groups=resnet_groups, dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels, motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [Downsample3D(out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op")]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        output_states = ()

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)

                # add motion module
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states


class CrossAttnUpBlock3D(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int,
        dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32,
        resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, output_scale_factor=1.0,
        add_upsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False,
        upcast_attention=False, unet_use_cross_frame_attention=False, unet_use_temporal_attention=False,
        use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None,
        add_audio_layer=False, audio_condition_method="cross_attn", custom_audio_layer=False,
    ):
        super().__init__()
        resnets = []
        attentions = []
        audio_attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels,
                    temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels,
                    num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention, use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer, audio_condition_method=audio_condition_method,
                )
            )
            audio_attentions.append(
                Transformer3DModel(
                    attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels,
                    num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention, use_motion_module=use_motion_module,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    add_audio_layer=add_audio_layer, audio_condition_method=audio_condition_method,
                    custom_audio_layer=True,
                )
                if custom_audio_layer
                else None
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels, motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.audio_attentions = nn.ModuleList(audio_attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
    ):
        for resnet, attn, audio_attn, motion_module in zip(
            self.resnets, self.attentions, self.audio_attentions, self.motion_modules
        ):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                hidden_states = (
                    audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                    if audio_attn is not None
                    else hidden_states
                )

                # add motion module
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class UpBlock3D(nn.Module):
    def __init__(
        self, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int,
        dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32,
        resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, use_inflated_groupnorm=False,
        use_motion_module=None, motion_module_type=None, motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels,
                    temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels, motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None):
        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
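get_down_block and get_up_block above are thin factories: they strip an optional "UNetRes" prefix, dispatch on the block-type string, forward every option to the matching class, and raise ValueError for unknown names. A minimal dispatch sketch follows; the argument values are illustrative rather than taken from a shipped LatentSync config, and the constructed block still depends on the local resnet/attention modules imported at the top of this file.

# Hypothetical dispatch sketch with illustrative values.
block = get_down_block(
    "DownBlock3D",
    num_layers=2, in_channels=320, out_channels=320, temb_channels=1280,
    add_downsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
    attn_num_head_channels=8, resnet_groups=32, downsample_padding=1,
)
# block(...) returns (hidden_states, output_states): the downsampled features plus the per-layer
# residuals that the corresponding up block later concatenates back in along the channel dim.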
LatentSync/latentsync/models/utils.py
0 → 100644
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module
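zero_module is what unet.py above wraps around conv_in and conv_out: zeroing a layer's parameters in place makes the wrapped layer contribute nothing at initialization, so a newly added input/output projection starts out as a no-op on top of pretrained weights while remaining trainable. A short self-contained sketch of the effect (the nn.Conv2d here is just an example layer, not one used by the repo):

import torch
from torch import nn

proj = zero_module(nn.Conv2d(4, 4, kernel_size=3, padding=1))
x = torch.randn(1, 4, 8, 8)
# Zero weights and zero bias give a zero output at initialization; the parameters still receive gradients.
assert torch.allclose(proj(x), torch.zeros_like(x))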
LatentSync/latentsync/pipelines/lipsync_pipeline.py
0 → 100644
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/pipelines/pipeline_animation.py

import inspect
import os
import shutil
from typing import Callable, List, Optional, Union
import subprocess

import numpy as np
import torch
import torchvision

from diffusers.utils import is_accelerate_available
from packaging import version

from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import deprecate, logging

from einops import rearrange

from ..models.unet import UNet3DConditionModel
from ..utils.image_processor import ImageProcessor
from ..utils.util import read_video, read_audio, write_video
from ..whisper.audio2feature import Audio2Feature

import tqdm
import soundfile as sf

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class LipsyncPipeline(DiffusionPipeline):
    _optional_components = []

    def __init__(
        self,
        vae: AutoencoderKL,
        audio_encoder: Audio2Feature,
        unet: UNet3DConditionModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            audio_encoder=audio_encoder,
            unet=unet,
            scheduler=scheduler,
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        self.set_progress_bar_config(desc="Steps")

    def enable_vae_slicing(self):
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def decode_latents(self, latents):
        latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        decoded_latents = self.vae.decode(latents).sample
        return decoded_latents

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, height, width, callback_steps):
        assert height == width, "Height and width must be equal"

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_frames, num_channels_latents, height, width, dtype, device, generator):
        shape = (
            batch_size,
            num_channels_latents,
            1,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        rand_device = "cpu" if device.type == "mps" else device
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
        latents = latents.repeat(1, 1, num_frames, 1, 1)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def prepare_mask_latents(self, mask, masked_image, height, width, dtype, device, generator, do_classifier_free_guidance):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask = torch.nn.functional.interpolate(
            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
        )
        masked_image = masked_image.to(device=device, dtype=dtype)

        # encode the mask image into latents space so we can concatenate it to the latents
        masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        # aligning device to prevent device errors when concating it with the latent model input
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
        mask = mask.to(device=device, dtype=dtype)

        # assume batch size = 1
        mask = rearrange(mask, "f c h w -> 1 c f h w")
        masked_image_latents = rearrange(masked_image_latents, "f c h w -> 1 c f h w")

        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
        masked_image_latents = (
            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )
        return mask, masked_image_latents

    def prepare_image_latents(self, images, device, dtype, generator, do_classifier_free_guidance):
        images = images.to(device=device, dtype=dtype)
        image_latents = self.vae.encode(images).latent_dist.sample(generator=generator)
        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        image_latents = rearrange(image_latents, "f c h w -> 1 c f h w")
        image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents

        return image_latents

    def set_progress_bar_config(self, **kwargs):
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        self._progress_bar_config.update(kwargs)

    @staticmethod
    def paste_surrounding_pixels_back(decoded_latents, pixel_values, masks, device, weight_dtype):
        # Paste the surrounding pixels back, because we only want to change the mouth region
        pixel_values = pixel_values.to(device=device, dtype=weight_dtype)
        masks = masks.to(device=device, dtype=weight_dtype)
        combined_pixel_values = decoded_latents * masks + pixel_values * (1 - masks)
        return combined_pixel_values

    @staticmethod
    def pixel_values_to_images(pixel_values: torch.Tensor):
        pixel_values = rearrange(pixel_values, "f c h w -> f h w c")
        pixel_values = (pixel_values / 2 + 0.5).clamp(0, 1)
        images = (pixel_values * 255).to(torch.uint8)
        images = images.cpu().numpy()
        return images

    def affine_transform_video(self, video_path):
        video_frames = read_video(video_path, use_decord=False)
        faces = []
        boxes = []
        affine_matrices = []
        print(f"Affine transforming {len(video_frames)} faces...")
        for frame in tqdm.tqdm(video_frames):
            face, box, affine_matrix = self.image_processor.affine_transform(frame)
            faces.append(face)
            boxes.append(box)
            affine_matrices.append(affine_matrix)

        faces = torch.stack(faces)
        return faces, video_frames, boxes, affine_matrices

    def restore_video(self, faces, video_frames, boxes, affine_matrices):
        video_frames = video_frames[: faces.shape[0]]
        out_frames = []
        for index, face in enumerate(faces):
            x1, y1, x2, y2 = boxes[index]
            height = int(y2 - y1)
            width = int(x2 - x1)
            face = torchvision.transforms.functional.resize(face, size=(height, width), antialias=True)
            face = rearrange(face, "c h w -> h w c")
            face = (face / 2 + 0.5).clamp(0, 1)
            face = (face * 255).to(torch.uint8).cpu().numpy()
            out_frame = self.image_processor.restorer.restore_img(video_frames[index], face, affine_matrices[index])
            out_frames.append(out_frame)
        return np.stack(out_frames, axis=0)

    @torch.no_grad()
    def __call__(
        self,
        video_path: str,
        audio_path: str,
        video_out_path: str,
        video_mask_path: str = None,
        num_frames: int = 16,
        video_fps: int = 25,
        audio_sample_rate: int = 16000,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 20,
        guidance_scale: float = 1.5,
        weight_dtype: Optional[torch.dtype] = torch.float16,
        eta: float = 0.0,
        mask: str = "fix_mask",
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1
,
**
kwargs
,
):
is_train
=
self
.
unet
.
training
self
.
unet
.
eval
()
# 0. Define call parameters
batch_size
=
1
device
=
self
.
_execution_device
self
.
image_processor
=
ImageProcessor
(
height
,
mask
=
mask
,
device
=
"cuda"
)
self
.
set_progress_bar_config
(
desc
=
f
"Sample frames:
{
num_frames
}
"
)
video_frames
,
original_video_frames
,
boxes
,
affine_matrices
=
self
.
affine_transform_video
(
video_path
)
audio_samples
=
read_audio
(
audio_path
)
# 1. Default height and width to unet
height
=
height
or
self
.
unet
.
config
.
sample_size
*
self
.
vae_scale_factor
width
=
width
or
self
.
unet
.
config
.
sample_size
*
self
.
vae_scale_factor
# 2. Check inputs
self
.
check_inputs
(
height
,
width
,
callback_steps
)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance
=
guidance_scale
>
1.0
# 3. set timesteps
self
.
scheduler
.
set_timesteps
(
num_inference_steps
,
device
=
device
)
timesteps
=
self
.
scheduler
.
timesteps
# 4. Prepare extra step kwargs.
extra_step_kwargs
=
self
.
prepare_extra_step_kwargs
(
generator
,
eta
)
self
.
video_fps
=
video_fps
if
self
.
unet
.
add_audio_layer
:
whisper_feature
=
self
.
audio_encoder
.
audio2feat
(
audio_path
)
whisper_chunks
=
self
.
audio_encoder
.
feature2chunks
(
feature_array
=
whisper_feature
,
fps
=
video_fps
)
num_inferences
=
min
(
len
(
video_frames
),
len
(
whisper_chunks
))
//
num_frames
else
:
num_inferences
=
len
(
video_frames
)
//
num_frames
synced_video_frames
=
[]
masked_video_frames
=
[]
num_channels_latents
=
self
.
vae
.
config
.
latent_channels
# Prepare latent variables
all_latents
=
self
.
prepare_latents
(
batch_size
,
num_frames
*
num_inferences
,
num_channels_latents
,
height
,
width
,
weight_dtype
,
device
,
generator
,
)
for
i
in
tqdm
.
tqdm
(
range
(
num_inferences
),
desc
=
"Doing inference..."
):
if
self
.
unet
.
add_audio_layer
:
audio_embeds
=
torch
.
stack
(
whisper_chunks
[
i
*
num_frames
:
(
i
+
1
)
*
num_frames
])
audio_embeds
=
audio_embeds
.
to
(
device
,
dtype
=
weight_dtype
)
if
do_classifier_free_guidance
:
empty_audio_embeds
=
torch
.
zeros_like
(
audio_embeds
)
audio_embeds
=
torch
.
cat
([
empty_audio_embeds
,
audio_embeds
])
else
:
audio_embeds
=
None
inference_video_frames
=
video_frames
[
i
*
num_frames
:
(
i
+
1
)
*
num_frames
]
latents
=
all_latents
[:,
:,
i
*
num_frames
:
(
i
+
1
)
*
num_frames
]
pixel_values
,
masked_pixel_values
,
masks
=
self
.
image_processor
.
prepare_masks_and_masked_images
(
inference_video_frames
,
affine_transform
=
False
)
# 7. Prepare mask latent variables
mask_latents
,
masked_image_latents
=
self
.
prepare_mask_latents
(
masks
,
masked_pixel_values
,
height
,
width
,
weight_dtype
,
device
,
generator
,
do_classifier_free_guidance
,
)
# 8. Prepare image latents
image_latents
=
self
.
prepare_image_latents
(
pixel_values
,
device
,
weight_dtype
,
generator
,
do_classifier_free_guidance
,
)
# 9. Denoising loop
num_warmup_steps
=
len
(
timesteps
)
-
num_inference_steps
*
self
.
scheduler
.
order
with
self
.
progress_bar
(
total
=
num_inference_steps
)
as
progress_bar
:
for
j
,
t
in
enumerate
(
timesteps
):
# expand the latents if we are doing classifier free guidance
latent_model_input
=
torch
.
cat
([
latents
]
*
2
)
if
do_classifier_free_guidance
else
latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input
=
self
.
scheduler
.
scale_model_input
(
latent_model_input
,
t
)
latent_model_input
=
torch
.
cat
(
[
latent_model_input
,
mask_latents
,
masked_image_latents
,
image_latents
],
dim
=
1
)
# predict the noise residual
noise_pred
=
self
.
unet
(
latent_model_input
,
t
,
encoder_hidden_states
=
audio_embeds
).
sample
# perform guidance
if
do_classifier_free_guidance
:
noise_pred_uncond
,
noise_pred_audio
=
noise_pred
.
chunk
(
2
)
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
(
noise_pred_audio
-
noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
).
prev_sample
# call the callback, if provided
if
j
==
len
(
timesteps
)
-
1
or
((
j
+
1
)
>
num_warmup_steps
and
(
j
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
callback
is
not
None
and
j
%
callback_steps
==
0
:
callback
(
j
,
t
,
latents
)
# Recover the pixel values
decoded_latents
=
self
.
decode_latents
(
latents
)
decoded_latents
=
self
.
paste_surrounding_pixels_back
(
decoded_latents
,
pixel_values
,
1
-
masks
,
device
,
weight_dtype
)
synced_video_frames
.
append
(
decoded_latents
)
masked_video_frames
.
append
(
masked_pixel_values
)
synced_video_frames
=
self
.
restore_video
(
torch
.
cat
(
synced_video_frames
),
original_video_frames
,
boxes
,
affine_matrices
)
masked_video_frames
=
self
.
restore_video
(
torch
.
cat
(
masked_video_frames
),
original_video_frames
,
boxes
,
affine_matrices
)
audio_samples_remain_length
=
int
(
synced_video_frames
.
shape
[
0
]
/
video_fps
*
audio_sample_rate
)
audio_samples
=
audio_samples
[:
audio_samples_remain_length
].
cpu
().
numpy
()
if
is_train
:
self
.
unet
.
train
()
temp_dir
=
"temp"
if
os
.
path
.
exists
(
temp_dir
):
shutil
.
rmtree
(
temp_dir
)
os
.
makedirs
(
temp_dir
,
exist_ok
=
True
)
write_video
(
os
.
path
.
join
(
temp_dir
,
"video.mp4"
),
synced_video_frames
,
fps
=
25
)
# write_video(video_mask_path, masked_video_frames, fps=25)
sf
.
write
(
os
.
path
.
join
(
temp_dir
,
"audio.wav"
),
audio_samples
,
audio_sample_rate
)
command
=
f
"ffmpeg -y -loglevel error -nostdin -i
{
os
.
path
.
join
(
temp_dir
,
'video.mp4'
)
}
-i
{
os
.
path
.
join
(
temp_dir
,
'audio.wav'
)
}
-c:v libx264 -c:a aac -q:v 0 -q:a 0
{
video_out_path
}
"
subprocess
.
run
(
command
,
shell
=
True
)
LatentSync/latentsync/trepa/__init__.py
0 → 100644
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import torch.nn as nn
from einops import rearrange

from .third_party.VideoMAEv2.utils import load_videomae_model


class TREPALoss:
    def __init__(
        self,
        device="cuda",
        ckpt_path="/mnt/bn/maliva-gen-ai-v2/chunyu.li/checkpoints/vit_g_hybrid_pt_1200e_ssv2_ft.pth",
    ):
        self.model = load_videomae_model(device, ckpt_path).eval().to(dtype=torch.float16)
        self.model.requires_grad_(False)
        self.bce_loss = nn.BCELoss()

    def __call__(self, videos_fake, videos_real, loss_type="mse"):
        batch_size = videos_fake.shape[0]
        num_frames = videos_fake.shape[2]
        videos_fake = rearrange(videos_fake.clone(), "b c f h w -> (b f) c h w")
        videos_real = rearrange(videos_real.clone(), "b c f h w -> (b f) c h w")

        videos_fake = F.interpolate(videos_fake, size=(224, 224), mode="bilinear")
        videos_real = F.interpolate(videos_real, size=(224, 224), mode="bilinear")

        videos_fake = rearrange(videos_fake, "(b f) c h w -> b c f h w", f=num_frames)
        videos_real = rearrange(videos_real, "(b f) c h w -> b c f h w", f=num_frames)

        # Because input pixel range is [-1, 1], and model expects pixel range to be [0, 1]
        videos_fake = (videos_fake / 2 + 0.5).clamp(0, 1)
        videos_real = (videos_real / 2 + 0.5).clamp(0, 1)

        feats_fake = self.model.forward_features(videos_fake)
        feats_real = self.model.forward_features(videos_real)

        feats_fake = F.normalize(feats_fake, p=2, dim=1)
        feats_real = F.normalize(feats_real, p=2, dim=1)

        return F.mse_loss(feats_fake, feats_real)


if __name__ == "__main__":
    # input shape: (b, c, f, h, w)
    videos_fake = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
    videos_real = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)

    trepa_loss = TREPALoss(device="cuda")
    loss = trepa_loss(videos_fake, videos_real)
    print(loss)
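Editor's note: a minimal sketch of how TREPALoss could be folded into a training objective as an auxiliary feature-space term. The L1 reconstruction term and the 0.1 weight are assumptions for illustration, not values taken from this commit.

import torch.nn.functional as F

# trepa_loss = TREPALoss(device="cuda")  # as defined above
# decoded_videos, target_videos: (b, c, f, h, w) tensors in [-1, 1] (assumed)

def combined_loss(decoded_videos, target_videos, trepa_loss, trepa_weight=0.1):
    recon = F.l1_loss(decoded_videos, target_videos)          # pixel-space term (assumption)
    trepa = trepa_loss(decoded_videos, target_videos)         # VideoMAEv2 feature distance
    return recon + trepa_weight * trepa                        # weight is a placeholder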
LatentSync/latentsync/trepa/third_party/VideoMAEv2/__init__.py
0 → 100644
LatentSync/latentsync/trepa/third_party/VideoMAEv2/utils.py
0 → 100644
import os
import torch
import requests
from tqdm import tqdm
from torchvision import transforms

from .videomaev2_finetune import vit_giant_patch14_224


def to_normalized_float_tensor(vid):
    return vid.permute(3, 0, 1, 2).to(torch.float32) / 255


# NOTE: for those functions, which generally expect mini-batches, we keep them
# as non-minibatch so that they are applied as if they were 4d (thus image).
# this way, we only apply the transformation in the spatial domain
def resize(vid, size, interpolation='bilinear'):
    # NOTE: using bilinear interpolation because we don't work on minibatches
    # at this level
    scale = None
    if isinstance(size, int):
        scale = float(size) / min(vid.shape[-2:])
        size = None
    return torch.nn.functional.interpolate(vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False)


class ToFloatTensorInZeroOne(object):

    def __call__(self, vid):
        return to_normalized_float_tensor(vid)


class Resize(object):

    def __init__(self, size):
        self.size = size

    def __call__(self, vid):
        return resize(vid, self.size)


def preprocess_videomae(videos):
    transform = transforms.Compose([ToFloatTensorInZeroOne(), Resize((224, 224))])
    return torch.stack([transform(f) for f in torch.from_numpy(videos)])


def load_videomae_model(device, ckpt_path=None):
    if ckpt_path is None:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        ckpt_path = os.path.join(current_dir, 'vit_g_hybrid_pt_1200e_ssv2_ft.pth')

    if not os.path.exists(ckpt_path):
        # download the ckpt to the path
        ckpt_url = 'https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/videomaev2/vit_g_hybrid_pt_1200e_ssv2_ft.pth'
        response = requests.get(ckpt_url, stream=True, allow_redirects=True)
        total_size = int(response.headers.get("content-length", 0))
        block_size = 1024

        with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
            with open(ckpt_path, "wb") as fw:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    fw.write(data)

    model = vit_giant_patch14_224(
        img_size=224,
        pretrained=False,
        num_classes=174,
        all_frames=16,
        tubelet_size=2,
        drop_path_rate=0.3,
        use_mean_pooling=True,
    )

    ckpt = torch.load(ckpt_path, map_location='cpu')
    for model_key in ['model', 'module']:
        if model_key in ckpt:
            ckpt = ckpt[model_key]
            break
    model.load_state_dict(ckpt)
    return model.to(device)
\ No newline at end of file
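Editor's note: a short usage sketch for the helpers above, assuming a CUDA device is available and the checkpoint is either cached next to this module or downloadable; the dummy uint8 video array is a placeholder.

import numpy as np
import torch

# Dummy batch of clips, shape (N, T, H, W, C) in uint8.
videos = np.random.randint(0, 256, size=(1, 16, 256, 256, 3), dtype=np.uint8)

model = load_videomae_model("cuda")              # fetches vit_g_hybrid_pt_1200e_ssv2_ft.pth if missing
clips = preprocess_videomae(videos).to("cuda")   # (N, C, T, 224, 224), values in [0, 1]

with torch.no_grad():
    feats = model.forward_features(clips)        # one pooled feature vector per clip
print(feats.shape)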
LatentSync/latentsync/trepa/third_party/VideoMAEv2/videomaev2_finetune.py
0 → 100644
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
from
functools
import
partial
import
math
import
warnings
import
numpy
as
np
import
collections.abc
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
as
cp
from
itertools
import
repeat
def
_no_grad_trunc_normal_
(
tensor
,
mean
,
std
,
a
,
b
):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def
norm_cdf
(
x
):
# Computes standard normal cumulative distribution function
return
(
1.0
+
math
.
erf
(
x
/
math
.
sqrt
(
2.0
)))
/
2.0
if
(
mean
<
a
-
2
*
std
)
or
(
mean
>
b
+
2
*
std
):
warnings
.
warn
(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect."
,
stacklevel
=
2
,
)
with
torch
.
no_grad
():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l
=
norm_cdf
((
a
-
mean
)
/
std
)
u
=
norm_cdf
((
b
-
mean
)
/
std
)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor
.
uniform_
(
2
*
l
-
1
,
2
*
u
-
1
)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor
.
erfinv_
()
# Transform to proper mean, std
tensor
.
mul_
(
std
*
math
.
sqrt
(
2.0
))
tensor
.
add_
(
mean
)
# Clamp to ensure it's in the proper range
tensor
.
clamp_
(
min
=
a
,
max
=
b
)
return
tensor
def
trunc_normal_
(
tensor
,
mean
=
0.0
,
std
=
1.0
,
a
=-
2.0
,
b
=
2.0
):
r
"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return
_no_grad_trunc_normal_
(
tensor
,
mean
,
std
,
a
,
b
)
def
_ntuple
(
n
):
def
parse
(
x
):
if
isinstance
(
x
,
collections
.
abc
.
Iterable
):
return
x
return
tuple
(
repeat
(
x
,
n
))
return
parse
to_2tuple
=
_ntuple
(
2
)
def
drop_path
(
x
,
drop_prob
:
float
=
0.0
,
training
:
bool
=
False
):
"""
Adapted from timm codebase
"""
if
drop_prob
==
0.0
or
not
training
:
return
x
keep_prob
=
1
-
drop_prob
shape
=
(
x
.
shape
[
0
],)
+
(
1
,)
*
(
x
.
ndim
-
1
)
# work with diff dim tensors, not just 2D ConvNets
random_tensor
=
keep_prob
+
torch
.
rand
(
shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
random_tensor
.
floor_
()
# binarize
output
=
x
.
div
(
keep_prob
)
*
random_tensor
return
output
def
_cfg
(
url
=
""
,
**
kwargs
):
return
{
"url"
:
url
,
"num_classes"
:
400
,
"input_size"
:
(
3
,
224
,
224
),
"pool_size"
:
None
,
"crop_pct"
:
0.9
,
"interpolation"
:
"bicubic"
,
"mean"
:
(
0.5
,
0.5
,
0.5
),
"std"
:
(
0.5
,
0.5
,
0.5
),
**
kwargs
,
}
class
DropPath
(
nn
.
Module
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def
__init__
(
self
,
drop_prob
=
None
):
super
(
DropPath
,
self
).
__init__
()
self
.
drop_prob
=
drop_prob
def
forward
(
self
,
x
):
return
drop_path
(
x
,
self
.
drop_prob
,
self
.
training
)
def
extra_repr
(
self
)
->
str
:
return
"p={}"
.
format
(
self
.
drop_prob
)
class
Mlp
(
nn
.
Module
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
act_layer
=
nn
.
GELU
,
drop
=
0.0
):
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
)
self
.
act
=
act_layer
()
self
.
fc2
=
nn
.
Linear
(
hidden_features
,
out_features
)
self
.
drop
=
nn
.
Dropout
(
drop
)
def
forward
(
self
,
x
):
x
=
self
.
fc1
(
x
)
x
=
self
.
act
(
x
)
# x = self.drop(x)
# commit this for the orignal BERT implement
x
=
self
.
fc2
(
x
)
x
=
self
.
drop
(
x
)
return
x
class
CosAttention
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
=
8
,
qkv_bias
=
False
,
qk_scale
=
None
,
attn_drop
=
0.0
,
proj_drop
=
0.0
,
attn_head_dim
=
None
):
super
().
__init__
()
self
.
num_heads
=
num_heads
head_dim
=
dim
//
num_heads
if
attn_head_dim
is
not
None
:
head_dim
=
attn_head_dim
all_head_dim
=
head_dim
*
self
.
num_heads
# self.scale = qk_scale or head_dim**-0.5
# DO NOT RENAME [self.scale] (for no weight decay)
if
qk_scale
is
None
:
self
.
scale
=
nn
.
Parameter
(
torch
.
log
(
10
*
torch
.
ones
((
num_heads
,
1
,
1
))),
requires_grad
=
True
)
else
:
self
.
scale
=
qk_scale
self
.
qkv
=
nn
.
Linear
(
dim
,
all_head_dim
*
3
,
bias
=
False
)
if
qkv_bias
:
self
.
q_bias
=
nn
.
Parameter
(
torch
.
zeros
(
all_head_dim
))
self
.
v_bias
=
nn
.
Parameter
(
torch
.
zeros
(
all_head_dim
))
else
:
self
.
q_bias
=
None
self
.
v_bias
=
None
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
all_head_dim
,
dim
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
def
forward
(
self
,
x
):
B
,
N
,
C
=
x
.
shape
qkv_bias
=
None
if
self
.
q_bias
is
not
None
:
qkv_bias
=
torch
.
cat
((
self
.
q_bias
,
torch
.
zeros_like
(
self
.
v_bias
,
requires_grad
=
False
),
self
.
v_bias
))
qkv
=
F
.
linear
(
input
=
x
,
weight
=
self
.
qkv
.
weight
,
bias
=
qkv_bias
)
qkv
=
qkv
.
reshape
(
B
,
N
,
3
,
self
.
num_heads
,
-
1
).
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
qkv
[
0
],
qkv
[
1
],
qkv
[
2
]
# make torchscript happy (cannot use tensor as tuple)
attn
=
F
.
normalize
(
q
,
dim
=-
1
)
@
F
.
normalize
(
k
,
dim
=-
1
).
transpose
(
-
2
,
-
1
)
# torch.log(torch.tensor(1. / 0.01)) = 4.6052
logit_scale
=
torch
.
clamp
(
self
.
scale
,
max
=
4.6052
).
exp
()
attn
=
attn
*
logit_scale
attn
=
attn
.
softmax
(
dim
=-
1
)
attn
=
self
.
attn_drop
(
attn
)
x
=
(
attn
@
v
).
transpose
(
1
,
2
).
reshape
(
B
,
N
,
-
1
)
x
=
self
.
proj
(
x
)
x
=
self
.
proj_drop
(
x
)
return
x
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
=
8
,
qkv_bias
=
False
,
qk_scale
=
None
,
attn_drop
=
0.0
,
proj_drop
=
0.0
,
attn_head_dim
=
None
):
super
().
__init__
()
self
.
num_heads
=
num_heads
head_dim
=
dim
//
num_heads
if
attn_head_dim
is
not
None
:
head_dim
=
attn_head_dim
all_head_dim
=
head_dim
*
self
.
num_heads
self
.
scale
=
qk_scale
or
head_dim
**-
0.5
self
.
qkv
=
nn
.
Linear
(
dim
,
all_head_dim
*
3
,
bias
=
False
)
if
qkv_bias
:
self
.
q_bias
=
nn
.
Parameter
(
torch
.
zeros
(
all_head_dim
))
self
.
v_bias
=
nn
.
Parameter
(
torch
.
zeros
(
all_head_dim
))
else
:
self
.
q_bias
=
None
self
.
v_bias
=
None
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
all_head_dim
,
dim
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
def
forward
(
self
,
x
):
B
,
N
,
C
=
x
.
shape
qkv_bias
=
None
if
self
.
q_bias
is
not
None
:
qkv_bias
=
torch
.
cat
((
self
.
q_bias
,
torch
.
zeros_like
(
self
.
v_bias
,
requires_grad
=
False
),
self
.
v_bias
))
qkv
=
F
.
linear
(
input
=
x
,
weight
=
self
.
qkv
.
weight
,
bias
=
qkv_bias
)
qkv
=
qkv
.
reshape
(
B
,
N
,
3
,
self
.
num_heads
,
-
1
).
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
qkv
[
0
],
qkv
[
1
],
qkv
[
2
]
# make torchscript happy (cannot use tensor as tuple)
q
=
q
*
self
.
scale
attn
=
q
@
k
.
transpose
(
-
2
,
-
1
)
attn
=
attn
.
softmax
(
dim
=-
1
)
attn
=
self
.
attn_drop
(
attn
)
x
=
(
attn
@
v
).
transpose
(
1
,
2
).
reshape
(
B
,
N
,
-
1
)
x
=
self
.
proj
(
x
)
x
=
self
.
proj_drop
(
x
)
return
x
class
Block
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
,
mlp_ratio
=
4.0
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop
=
0.0
,
attn_drop
=
0.0
,
drop_path
=
0.0
,
init_values
=
None
,
act_layer
=
nn
.
GELU
,
norm_layer
=
nn
.
LayerNorm
,
attn_head_dim
=
None
,
cos_attn
=
False
,
):
super
().
__init__
()
self
.
norm1
=
norm_layer
(
dim
)
if
cos_attn
:
self
.
attn
=
CosAttention
(
dim
,
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
,
attn_head_dim
=
attn_head_dim
,
)
else
:
self
.
attn
=
Attention
(
dim
,
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
,
attn_head_dim
=
attn_head_dim
,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.0
else
nn
.
Identity
()
self
.
norm2
=
norm_layer
(
dim
)
mlp_hidden_dim
=
int
(
dim
*
mlp_ratio
)
self
.
mlp
=
Mlp
(
in_features
=
dim
,
hidden_features
=
mlp_hidden_dim
,
act_layer
=
act_layer
,
drop
=
drop
)
if
init_values
>
0
:
self
.
gamma_1
=
nn
.
Parameter
(
init_values
*
torch
.
ones
((
dim
)),
requires_grad
=
True
)
self
.
gamma_2
=
nn
.
Parameter
(
init_values
*
torch
.
ones
((
dim
)),
requires_grad
=
True
)
else
:
self
.
gamma_1
,
self
.
gamma_2
=
None
,
None
def
forward
(
self
,
x
):
if
self
.
gamma_1
is
None
:
x
=
x
+
self
.
drop_path
(
self
.
attn
(
self
.
norm1
(
x
)))
x
=
x
+
self
.
drop_path
(
self
.
mlp
(
self
.
norm2
(
x
)))
else
:
x
=
x
+
self
.
drop_path
(
self
.
gamma_1
*
self
.
attn
(
self
.
norm1
(
x
)))
x
=
x
+
self
.
drop_path
(
self
.
gamma_2
*
self
.
mlp
(
self
.
norm2
(
x
)))
return
x
class
PatchEmbed
(
nn
.
Module
):
"""Image to Patch Embedding"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
embed_dim
=
768
,
num_frames
=
16
,
tubelet_size
=
2
):
super
().
__init__
()
img_size
=
to_2tuple
(
img_size
)
patch_size
=
to_2tuple
(
patch_size
)
num_spatial_patches
=
(
img_size
[
0
]
//
patch_size
[
0
])
*
(
img_size
[
1
]
//
patch_size
[
1
])
num_patches
=
num_spatial_patches
*
(
num_frames
//
tubelet_size
)
self
.
img_size
=
img_size
self
.
tubelet_size
=
tubelet_size
self
.
patch_size
=
patch_size
self
.
num_patches
=
num_patches
self
.
proj
=
nn
.
Conv3d
(
in_channels
=
in_chans
,
out_channels
=
embed_dim
,
kernel_size
=
(
self
.
tubelet_size
,
patch_size
[
0
],
patch_size
[
1
]),
stride
=
(
self
.
tubelet_size
,
patch_size
[
0
],
patch_size
[
1
]),
)
def
forward
(
self
,
x
,
**
kwargs
):
B
,
C
,
T
,
H
,
W
=
x
.
shape
assert
(
H
==
self
.
img_size
[
0
]
and
W
==
self
.
img_size
[
1
]
),
f
"Input image size (
{
H
}
*
{
W
}
) doesn't match model (
{
self
.
img_size
[
0
]
}
*
{
self
.
img_size
[
1
]
}
)."
# b, c, l -> b, l, c
# [1, 1408, 8, 16, 16] -> [1, 1408, 2048] -> [1, 2048, 1408]
x
=
self
.
proj
(
x
).
flatten
(
2
).
transpose
(
1
,
2
)
return
x
# sin-cos position encoding
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
def
get_sinusoid_encoding_table
(
n_position
,
d_hid
):
"""Sinusoid position encoding table"""
# TODO: make it with torch instead of numpy
def
get_position_angle_vec
(
position
):
return
[
position
/
np
.
power
(
10000
,
2
*
(
hid_j
//
2
)
/
d_hid
)
for
hid_j
in
range
(
d_hid
)]
sinusoid_table
=
np
.
array
([
get_position_angle_vec
(
pos_i
)
for
pos_i
in
range
(
n_position
)])
sinusoid_table
[:,
0
::
2
]
=
np
.
sin
(
sinusoid_table
[:,
0
::
2
])
# dim 2i
sinusoid_table
[:,
1
::
2
]
=
np
.
cos
(
sinusoid_table
[:,
1
::
2
])
# dim 2i+1
return
torch
.
tensor
(
sinusoid_table
,
dtype
=
torch
.
float
,
requires_grad
=
False
).
unsqueeze
(
0
)
class
VisionTransformer
(
nn
.
Module
):
"""Vision Transformer with support for patch or hybrid CNN input stage"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
num_classes
=
1000
,
embed_dim
=
768
,
depth
=
12
,
num_heads
=
12
,
mlp_ratio
=
4.0
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop_rate
=
0.0
,
attn_drop_rate
=
0.0
,
drop_path_rate
=
0.0
,
head_drop_rate
=
0.0
,
norm_layer
=
nn
.
LayerNorm
,
init_values
=
0.0
,
use_learnable_pos_emb
=
False
,
init_scale
=
0.0
,
all_frames
=
16
,
tubelet_size
=
2
,
use_mean_pooling
=
True
,
with_cp
=
False
,
cos_attn
=
False
,
):
super
().
__init__
()
self
.
num_classes
=
num_classes
# num_features for consistency with other models
self
.
num_features
=
self
.
embed_dim
=
embed_dim
self
.
tubelet_size
=
tubelet_size
self
.
patch_embed
=
PatchEmbed
(
img_size
=
img_size
,
patch_size
=
patch_size
,
in_chans
=
in_chans
,
embed_dim
=
embed_dim
,
num_frames
=
all_frames
,
tubelet_size
=
tubelet_size
,
)
num_patches
=
self
.
patch_embed
.
num_patches
self
.
with_cp
=
with_cp
if
use_learnable_pos_emb
:
self
.
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_patches
,
embed_dim
))
else
:
# sine-cosine positional embeddings is on the way
self
.
pos_embed
=
get_sinusoid_encoding_table
(
num_patches
,
embed_dim
)
self
.
pos_drop
=
nn
.
Dropout
(
p
=
drop_rate
)
dpr
=
[
x
.
item
()
for
x
in
torch
.
linspace
(
0
,
drop_path_rate
,
depth
)]
# stochastic depth decay rule
self
.
blocks
=
nn
.
ModuleList
(
[
Block
(
dim
=
embed_dim
,
num_heads
=
num_heads
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
i
],
norm_layer
=
norm_layer
,
init_values
=
init_values
,
cos_attn
=
cos_attn
,
)
for
i
in
range
(
depth
)
]
)
self
.
norm
=
nn
.
Identity
()
if
use_mean_pooling
else
norm_layer
(
embed_dim
)
self
.
fc_norm
=
norm_layer
(
embed_dim
)
if
use_mean_pooling
else
None
self
.
head_dropout
=
nn
.
Dropout
(
head_drop_rate
)
self
.
head
=
nn
.
Linear
(
embed_dim
,
num_classes
)
if
num_classes
>
0
else
nn
.
Identity
()
if
use_learnable_pos_emb
:
trunc_normal_
(
self
.
pos_embed
,
std
=
0.02
)
self
.
apply
(
self
.
_init_weights
)
self
.
head
.
weight
.
data
.
mul_
(
init_scale
)
self
.
head
.
bias
.
data
.
mul_
(
init_scale
)
self
.
num_frames
=
all_frames
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
0.02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
get_num_layers
(
self
):
return
len
(
self
.
blocks
)
@
torch
.
jit
.
ignore
def
no_weight_decay
(
self
):
return
{
"pos_embed"
,
"cls_token"
}
def
get_classifier
(
self
):
return
self
.
head
def
reset_classifier
(
self
,
num_classes
,
global_pool
=
""
):
self
.
num_classes
=
num_classes
self
.
head
=
nn
.
Linear
(
self
.
embed_dim
,
num_classes
)
if
num_classes
>
0
else
nn
.
Identity
()
def
interpolate_pos_encoding
(
self
,
t
):
T
=
8
t0
=
t
//
self
.
tubelet_size
if
T
==
t0
:
return
self
.
pos_embed
dim
=
self
.
pos_embed
.
shape
[
-
1
]
patch_pos_embed
=
self
.
pos_embed
.
permute
(
0
,
2
,
1
).
reshape
(
1
,
dim
,
8
,
16
,
16
)
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
t0
=
t0
+
0.1
patch_pos_embed
=
nn
.
functional
.
interpolate
(
patch_pos_embed
,
scale_factor
=
(
t0
/
T
,
1
,
1
),
mode
=
"trilinear"
,
)
assert
int
(
t0
)
==
patch_pos_embed
.
shape
[
-
3
]
patch_pos_embed
=
patch_pos_embed
.
reshape
(
1
,
dim
,
-
1
).
permute
(
0
,
2
,
1
)
return
patch_pos_embed
def
forward_features
(
self
,
x
):
# [1, 3, 16, 224, 224]
B
=
x
.
size
(
0
)
T
=
x
.
size
(
2
)
# [1, 2048, 1408]
x
=
self
.
patch_embed
(
x
)
if
self
.
pos_embed
is
not
None
:
x
=
x
+
self
.
interpolate_pos_encoding
(
T
).
expand
(
B
,
-
1
,
-
1
).
type_as
(
x
).
to
(
x
.
device
).
clone
().
detach
()
x
=
self
.
pos_drop
(
x
)
for
blk
in
self
.
blocks
:
if
self
.
with_cp
:
x
=
cp
.
checkpoint
(
blk
,
x
)
else
:
x
=
blk
(
x
)
# return self.fc_norm(x)
if
self
.
fc_norm
is
not
None
:
return
self
.
fc_norm
(
x
.
mean
(
1
))
else
:
return
self
.
norm
(
x
[:,
0
])
def
forward
(
self
,
x
):
x
=
self
.
forward_features
(
x
)
x
=
self
.
head_dropout
(
x
)
x
=
self
.
head
(
x
)
return
x
def
vit_giant_patch14_224
(
pretrained
=
False
,
**
kwargs
):
model
=
VisionTransformer
(
patch_size
=
14
,
embed_dim
=
1408
,
depth
=
40
,
num_heads
=
16
,
mlp_ratio
=
48
/
11
,
qkv_bias
=
True
,
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
),
**
kwargs
,
)
model
.
default_cfg
=
_cfg
()
return
model
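Editor's note: a minimal sketch of instantiating the giant ViT factory defined above and pushing a dummy clip through it; the configuration mirrors the one used by load_videomae_model in utils.py, and the random weights here are only for shape checking.

import torch

model = vit_giant_patch14_224(
    img_size=224,
    pretrained=False,
    num_classes=174,
    all_frames=16,
    tubelet_size=2,
    drop_path_rate=0.3,
    use_mean_pooling=True,
).eval()

# One clip of 16 RGB frames at 224x224 (dummy data).
clip = torch.randn(1, 3, 16, 224, 224)
with torch.no_grad():
    feats = model.forward_features(clip)   # (1, 1408) mean-pooled token features
    logits = model.head(feats)             # (1, 174) classification logits
print(feats.shape, logits.shape)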
LatentSync/latentsync/trepa/third_party/VideoMAEv2/videomaev2_pretrain.py
0 → 100644
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
from
functools
import
partial
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
.videomaev2_finetune
import
(
Block
,
PatchEmbed
,
_cfg
,
get_sinusoid_encoding_table
,
)
from
.videomaev2_finetune
import
trunc_normal_
as
__call_trunc_normal_
def
trunc_normal_
(
tensor
,
mean
=
0.
,
std
=
1.
):
__call_trunc_normal_
(
tensor
,
mean
=
mean
,
std
=
std
,
a
=-
std
,
b
=
std
)
class
PretrainVisionTransformerEncoder
(
nn
.
Module
):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
num_classes
=
0
,
embed_dim
=
768
,
depth
=
12
,
num_heads
=
12
,
mlp_ratio
=
4.
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop_rate
=
0.
,
attn_drop_rate
=
0.
,
drop_path_rate
=
0.
,
norm_layer
=
nn
.
LayerNorm
,
init_values
=
None
,
tubelet_size
=
2
,
use_learnable_pos_emb
=
False
,
with_cp
=
False
,
all_frames
=
16
,
cos_attn
=
False
):
super
().
__init__
()
self
.
num_classes
=
num_classes
# num_features for consistency with other models
self
.
num_features
=
self
.
embed_dim
=
embed_dim
self
.
patch_embed
=
PatchEmbed
(
img_size
=
img_size
,
patch_size
=
patch_size
,
in_chans
=
in_chans
,
embed_dim
=
embed_dim
,
num_frames
=
all_frames
,
tubelet_size
=
tubelet_size
)
num_patches
=
self
.
patch_embed
.
num_patches
self
.
with_cp
=
with_cp
if
use_learnable_pos_emb
:
self
.
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_patches
+
1
,
embed_dim
))
else
:
# sine-cosine positional embeddings
self
.
pos_embed
=
get_sinusoid_encoding_table
(
num_patches
,
embed_dim
)
dpr
=
[
x
.
item
()
for
x
in
torch
.
linspace
(
0
,
drop_path_rate
,
depth
)
]
# stochastic depth decay rule
self
.
blocks
=
nn
.
ModuleList
([
Block
(
dim
=
embed_dim
,
num_heads
=
num_heads
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
i
],
norm_layer
=
norm_layer
,
init_values
=
init_values
,
cos_attn
=
cos_attn
)
for
i
in
range
(
depth
)
])
self
.
norm
=
norm_layer
(
embed_dim
)
self
.
head
=
nn
.
Linear
(
embed_dim
,
num_classes
)
if
num_classes
>
0
else
nn
.
Identity
()
if
use_learnable_pos_emb
:
trunc_normal_
(
self
.
pos_embed
,
std
=
.
02
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
nn
.
init
.
xavier_uniform_
(
m
.
weight
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
get_num_layers
(
self
):
return
len
(
self
.
blocks
)
@
torch
.
jit
.
ignore
def
no_weight_decay
(
self
):
return
{
'pos_embed'
,
'cls_token'
}
def
get_classifier
(
self
):
return
self
.
head
def
reset_classifier
(
self
,
num_classes
,
global_pool
=
''
):
self
.
num_classes
=
num_classes
self
.
head
=
nn
.
Linear
(
self
.
embed_dim
,
num_classes
)
if
num_classes
>
0
else
nn
.
Identity
()
def
forward_features
(
self
,
x
,
mask
):
x
=
self
.
patch_embed
(
x
)
x
=
x
+
self
.
pos_embed
.
type_as
(
x
).
to
(
x
.
device
).
clone
().
detach
()
B
,
_
,
C
=
x
.
shape
x_vis
=
x
[
~
mask
].
reshape
(
B
,
-
1
,
C
)
# ~mask means visible
for
blk
in
self
.
blocks
:
if
self
.
with_cp
:
x_vis
=
cp
.
checkpoint
(
blk
,
x_vis
)
else
:
x_vis
=
blk
(
x_vis
)
x_vis
=
self
.
norm
(
x_vis
)
return
x_vis
def
forward
(
self
,
x
,
mask
):
x
=
self
.
forward_features
(
x
,
mask
)
x
=
self
.
head
(
x
)
return
x
class
PretrainVisionTransformerDecoder
(
nn
.
Module
):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def
__init__
(
self
,
patch_size
=
16
,
num_classes
=
768
,
embed_dim
=
768
,
depth
=
12
,
num_heads
=
12
,
mlp_ratio
=
4.
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop_rate
=
0.
,
attn_drop_rate
=
0.
,
drop_path_rate
=
0.
,
norm_layer
=
nn
.
LayerNorm
,
init_values
=
None
,
num_patches
=
196
,
tubelet_size
=
2
,
with_cp
=
False
,
cos_attn
=
False
):
super
().
__init__
()
self
.
num_classes
=
num_classes
assert
num_classes
==
3
*
tubelet_size
*
patch_size
**
2
# num_features for consistency with other models
self
.
num_features
=
self
.
embed_dim
=
embed_dim
self
.
patch_size
=
patch_size
self
.
with_cp
=
with_cp
dpr
=
[
x
.
item
()
for
x
in
torch
.
linspace
(
0
,
drop_path_rate
,
depth
)
]
# stochastic depth decay rule
self
.
blocks
=
nn
.
ModuleList
([
Block
(
dim
=
embed_dim
,
num_heads
=
num_heads
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
i
],
norm_layer
=
norm_layer
,
init_values
=
init_values
,
cos_attn
=
cos_attn
)
for
i
in
range
(
depth
)
])
self
.
norm
=
norm_layer
(
embed_dim
)
self
.
head
=
nn
.
Linear
(
embed_dim
,
num_classes
)
if
num_classes
>
0
else
nn
.
Identity
()
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
nn
.
init
.
xavier_uniform_
(
m
.
weight
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
get_num_layers
(
self
):
return
len
(
self
.
blocks
)
@
torch
.
jit
.
ignore
def
no_weight_decay
(
self
):
return
{
'pos_embed'
,
'cls_token'
}
def
get_classifier
(
self
):
return
self
.
head
def
reset_classifier
(
self
,
num_classes
,
global_pool
=
''
):
self
.
num_classes
=
num_classes
self
.
head
=
nn
.
Linear
(
self
.
embed_dim
,
num_classes
)
if
num_classes
>
0
else
nn
.
Identity
()
def
forward
(
self
,
x
,
return_token_num
):
for
blk
in
self
.
blocks
:
if
self
.
with_cp
:
x
=
cp
.
checkpoint
(
blk
,
x
)
else
:
x
=
blk
(
x
)
if
return_token_num
>
0
:
# only return the mask tokens predict pixels
x
=
self
.
head
(
self
.
norm
(
x
[:,
-
return_token_num
:]))
else
:
# [B, N, 3*16^2]
x
=
self
.
head
(
self
.
norm
(
x
))
return
x
class
PretrainVisionTransformer
(
nn
.
Module
):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
encoder_in_chans
=
3
,
encoder_num_classes
=
0
,
encoder_embed_dim
=
768
,
encoder_depth
=
12
,
encoder_num_heads
=
12
,
decoder_num_classes
=
1536
,
# decoder_num_classes=768
decoder_embed_dim
=
512
,
decoder_depth
=
8
,
decoder_num_heads
=
8
,
mlp_ratio
=
4.
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop_rate
=
0.
,
attn_drop_rate
=
0.
,
drop_path_rate
=
0.
,
norm_layer
=
nn
.
LayerNorm
,
init_values
=
0.
,
use_learnable_pos_emb
=
False
,
tubelet_size
=
2
,
num_classes
=
0
,
# avoid the error from create_fn in timm
in_chans
=
0
,
# avoid the error from create_fn in timm
with_cp
=
False
,
all_frames
=
16
,
cos_attn
=
False
,
):
super
().
__init__
()
self
.
encoder
=
PretrainVisionTransformerEncoder
(
img_size
=
img_size
,
patch_size
=
patch_size
,
in_chans
=
encoder_in_chans
,
num_classes
=
encoder_num_classes
,
embed_dim
=
encoder_embed_dim
,
depth
=
encoder_depth
,
num_heads
=
encoder_num_heads
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop_rate
=
drop_rate
,
attn_drop_rate
=
attn_drop_rate
,
drop_path_rate
=
drop_path_rate
,
norm_layer
=
norm_layer
,
init_values
=
init_values
,
tubelet_size
=
tubelet_size
,
use_learnable_pos_emb
=
use_learnable_pos_emb
,
with_cp
=
with_cp
,
all_frames
=
all_frames
,
cos_attn
=
cos_attn
)
self
.
decoder
=
PretrainVisionTransformerDecoder
(
patch_size
=
patch_size
,
num_patches
=
self
.
encoder
.
patch_embed
.
num_patches
,
num_classes
=
decoder_num_classes
,
embed_dim
=
decoder_embed_dim
,
depth
=
decoder_depth
,
num_heads
=
decoder_num_heads
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop_rate
=
drop_rate
,
attn_drop_rate
=
attn_drop_rate
,
drop_path_rate
=
drop_path_rate
,
norm_layer
=
norm_layer
,
init_values
=
init_values
,
tubelet_size
=
tubelet_size
,
with_cp
=
with_cp
,
cos_attn
=
cos_attn
)
self
.
encoder_to_decoder
=
nn
.
Linear
(
encoder_embed_dim
,
decoder_embed_dim
,
bias
=
False
)
self
.
mask_token
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
1
,
decoder_embed_dim
))
self
.
pos_embed
=
get_sinusoid_encoding_table
(
self
.
encoder
.
patch_embed
.
num_patches
,
decoder_embed_dim
)
trunc_normal_
(
self
.
mask_token
,
std
=
.
02
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
nn
.
init
.
xavier_uniform_
(
m
.
weight
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
get_num_layers
(
self
):
return
len
(
self
.
blocks
)
@
torch
.
jit
.
ignore
def
no_weight_decay
(
self
):
return
{
'pos_embed'
,
'cls_token'
,
'mask_token'
}
def
forward
(
self
,
x
,
mask
,
decode_mask
=
None
):
decode_vis
=
mask
if
decode_mask
is
None
else
~
decode_mask
x_vis
=
self
.
encoder
(
x
,
mask
)
# [B, N_vis, C_e]
x_vis
=
self
.
encoder_to_decoder
(
x_vis
)
# [B, N_vis, C_d]
B
,
N_vis
,
C
=
x_vis
.
shape
# we don't unshuffle the correct visible token order,
# but shuffle the pos embedding accorddingly.
expand_pos_embed
=
self
.
pos_embed
.
expand
(
B
,
-
1
,
-
1
).
type_as
(
x
).
to
(
x
.
device
).
clone
().
detach
()
pos_emd_vis
=
expand_pos_embed
[
~
mask
].
reshape
(
B
,
-
1
,
C
)
pos_emd_mask
=
expand_pos_embed
[
decode_vis
].
reshape
(
B
,
-
1
,
C
)
# [B, N, C_d]
x_full
=
torch
.
cat
(
[
x_vis
+
pos_emd_vis
,
self
.
mask_token
+
pos_emd_mask
],
dim
=
1
)
# NOTE: if N_mask==0, the shape of x is [B, N_mask, 3 * 16 * 16]
x
=
self
.
decoder
(
x_full
,
pos_emd_mask
.
shape
[
1
])
return
x
def
pretrain_videomae_small_patch16_224
(
pretrained
=
False
,
**
kwargs
):
model
=
PretrainVisionTransformer
(
img_size
=
224
,
patch_size
=
16
,
encoder_embed_dim
=
384
,
encoder_depth
=
12
,
encoder_num_heads
=
6
,
encoder_num_classes
=
0
,
decoder_num_classes
=
1536
,
# 16 * 16 * 3 * 2
decoder_embed_dim
=
192
,
decoder_num_heads
=
3
,
mlp_ratio
=
4
,
qkv_bias
=
True
,
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
),
**
kwargs
)
model
.
default_cfg
=
_cfg
()
if
pretrained
:
checkpoint
=
torch
.
load
(
kwargs
[
"init_ckpt"
],
map_location
=
"cpu"
)
model
.
load_state_dict
(
checkpoint
[
"model"
])
return
model
def
pretrain_videomae_base_patch16_224
(
pretrained
=
False
,
**
kwargs
):
model
=
PretrainVisionTransformer
(
img_size
=
224
,
patch_size
=
16
,
encoder_embed_dim
=
768
,
encoder_depth
=
12
,
encoder_num_heads
=
12
,
encoder_num_classes
=
0
,
decoder_num_classes
=
1536
,
# 16 * 16 * 3 * 2
decoder_embed_dim
=
384
,
decoder_num_heads
=
6
,
mlp_ratio
=
4
,
qkv_bias
=
True
,
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
),
**
kwargs
)
model
.
default_cfg
=
_cfg
()
if
pretrained
:
checkpoint
=
torch
.
load
(
kwargs
[
"init_ckpt"
],
map_location
=
"cpu"
)
model
.
load_state_dict
(
checkpoint
[
"model"
])
return
model
def
pretrain_videomae_large_patch16_224
(
pretrained
=
False
,
**
kwargs
):
model
=
PretrainVisionTransformer
(
img_size
=
224
,
patch_size
=
16
,
encoder_embed_dim
=
1024
,
encoder_depth
=
24
,
encoder_num_heads
=
16
,
encoder_num_classes
=
0
,
decoder_num_classes
=
1536
,
# 16 * 16 * 3 * 2
decoder_embed_dim
=
512
,
decoder_num_heads
=
8
,
mlp_ratio
=
4
,
qkv_bias
=
True
,
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
),
**
kwargs
)
model
.
default_cfg
=
_cfg
()
if
pretrained
:
checkpoint
=
torch
.
load
(
kwargs
[
"init_ckpt"
],
map_location
=
"cpu"
)
model
.
load_state_dict
(
checkpoint
[
"model"
])
return
model
def
pretrain_videomae_huge_patch16_224
(
pretrained
=
False
,
**
kwargs
):
model
=
PretrainVisionTransformer
(
img_size
=
224
,
patch_size
=
16
,
encoder_embed_dim
=
1280
,
encoder_depth
=
32
,
encoder_num_heads
=
16
,
encoder_num_classes
=
0
,
decoder_num_classes
=
1536
,
# 16 * 16 * 3 * 2
decoder_embed_dim
=
512
,
decoder_num_heads
=
8
,
mlp_ratio
=
4
,
qkv_bias
=
True
,
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
),
**
kwargs
)
model
.
default_cfg
=
_cfg
()
if
pretrained
:
checkpoint
=
torch
.
load
(
kwargs
[
"init_ckpt"
],
map_location
=
"cpu"
)
model
.
load_state_dict
(
checkpoint
[
"model"
])
return
model
def
pretrain_videomae_giant_patch14_224
(
pretrained
=
False
,
**
kwargs
):
model
=
PretrainVisionTransformer
(
img_size
=
224
,
patch_size
=
14
,
encoder_embed_dim
=
1408
,
encoder_depth
=
40
,
encoder_num_heads
=
16
,
encoder_num_classes
=
0
,
decoder_num_classes
=
1176
,
# 14 * 14 * 3 * 2,
decoder_embed_dim
=
512
,
decoder_num_heads
=
8
,
mlp_ratio
=
48
/
11
,
qkv_bias
=
True
,
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
),
**
kwargs
)
model
.
default_cfg
=
_cfg
()
if
pretrained
:
checkpoint
=
torch
.
load
(
kwargs
[
"init_ckpt"
],
map_location
=
"cpu"
)
model
.
load_state_dict
(
checkpoint
[
"model"
])
return
model
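Editor's note: a minimal sketch of the masked-autoencoding interface defined above, using the base configuration; the ~90% masking ratio is an assumption chosen for illustration and is not taken from this commit.

import torch

model = pretrain_videomae_base_patch16_224(pretrained=False).eval()

# One clip: 16 frames of 224x224 RGB.
x = torch.randn(1, 3, 16, 224, 224)

# Boolean mask over the (224/16)^2 * (16/2) = 1568 tube tokens; True means masked.
num_tokens = 14 * 14 * 8
mask = torch.zeros(1, num_tokens, dtype=torch.bool)
mask[:, torch.randperm(num_tokens)[: int(0.9 * num_tokens)]] = True  # masking ratio is a placeholder

with torch.no_grad():
    pred = model(x, mask)  # predicted pixels for the masked tokens, shape (1, N_masked, 1536)
print(pred.shape)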
LatentSync/latentsync/trepa/third_party/__init__.py
0 → 100644
LatentSync/latentsync/trepa/utils/__init__.py
0 → 100644
LatentSync/latentsync/trepa/utils/data_utils.py
0 → 100644
import
os
import
math
import
os.path
as
osp
import
random
import
pickle
import
warnings
import
glob
import
numpy
as
np
from
PIL
import
Image
import
torch
import
torch.utils.data
as
data
import
torch.nn.functional
as
F
import
torch.distributed
as
dist
from
torchvision.datasets.video_utils
import
VideoClips
IMG_EXTENSIONS
=
[
'.jpg'
,
'.JPG'
,
'.jpeg'
,
'.JPEG'
,
'.png'
,
'.PNG'
]
VID_EXTENSIONS
=
[
'.avi'
,
'.mp4'
,
'.webm'
,
'.mov'
,
'.mkv'
,
'.m4v'
]
def
get_dataloader
(
data_path
,
image_folder
,
resolution
=
128
,
sequence_length
=
16
,
sample_every_n_frames
=
1
,
batch_size
=
16
,
num_workers
=
8
):
data
=
VideoData
(
data_path
,
image_folder
,
resolution
,
sequence_length
,
sample_every_n_frames
,
batch_size
,
num_workers
)
loader
=
data
.
_dataloader
()
return
loader
def
is_image_file
(
filename
):
return
any
(
filename
.
endswith
(
extension
)
for
extension
in
IMG_EXTENSIONS
)
def
get_parent_dir
(
path
):
return
osp
.
basename
(
osp
.
dirname
(
path
))
def
preprocess
(
video
,
resolution
,
sequence_length
=
None
,
in_channels
=
3
,
sample_every_n_frames
=
1
):
# video: THWC, {0, ..., 255}
assert
in_channels
==
3
video
=
video
.
permute
(
0
,
3
,
1
,
2
).
float
()
/
255.
# TCHW
t
,
c
,
h
,
w
=
video
.
shape
# temporal crop
if
sequence_length
is
not
None
:
assert
sequence_length
<=
t
video
=
video
[:
sequence_length
]
# skip frames
if
sample_every_n_frames
>
1
:
video
=
video
[::
sample_every_n_frames
]
# scale shorter side to resolution
scale
=
resolution
/
min
(
h
,
w
)
if
h
<
w
:
target_size
=
(
resolution
,
math
.
ceil
(
w
*
scale
))
else
:
target_size
=
(
math
.
ceil
(
h
*
scale
),
resolution
)
video
=
F
.
interpolate
(
video
,
size
=
target_size
,
mode
=
'bilinear'
,
align_corners
=
False
,
antialias
=
True
)
# center crop
t
,
c
,
h
,
w
=
video
.
shape
w_start
=
(
w
-
resolution
)
//
2
h_start
=
(
h
-
resolution
)
//
2
video
=
video
[:,
:,
h_start
:
h_start
+
resolution
,
w_start
:
w_start
+
resolution
]
video
=
video
.
permute
(
1
,
0
,
2
,
3
).
contiguous
()
# CTHW
return
{
'video'
:
video
}
def
preprocess_image
(
image
):
# [0, 1] => [-1, 1]
img
=
torch
.
from_numpy
(
image
)
return
img
class
VideoData
(
data
.
Dataset
):
""" Class to create dataloaders for video datasets
Args:
data_path: Path to the folder with video frames or videos.
image_folder: If True, the data is stored as images in folders.
resolution: Resolution of the returned videos.
sequence_length: Length of extracted video sequences.
sample_every_n_frames: Sample every n frames from the video.
batch_size: Batch size.
num_workers: Number of workers for the dataloader.
shuffle: If True, shuffle the data.
"""
def
__init__
(
self
,
data_path
:
str
,
image_folder
:
bool
,
resolution
:
int
,
sequence_length
:
int
,
sample_every_n_frames
:
int
,
batch_size
:
int
,
num_workers
:
int
,
shuffle
:
bool
=
True
):
super
().
__init__
()
self
.
data_path
=
data_path
self
.
image_folder
=
image_folder
self
.
resolution
=
resolution
self
.
sequence_length
=
sequence_length
self
.
sample_every_n_frames
=
sample_every_n_frames
self
.
batch_size
=
batch_size
self
.
num_workers
=
num_workers
self
.
shuffle
=
shuffle
def
_dataset
(
self
):
'''
Initializes and return the dataset.
'''
if
self
.
image_folder
:
Dataset
=
FrameDataset
dataset
=
Dataset
(
self
.
data_path
,
self
.
sequence_length
,
resolution
=
self
.
resolution
,
sample_every_n_frames
=
self
.
sample_every_n_frames
)
else
:
Dataset
=
VideoDataset
dataset
=
Dataset
(
self
.
data_path
,
self
.
sequence_length
,
resolution
=
self
.
resolution
,
sample_every_n_frames
=
self
.
sample_every_n_frames
)
return
dataset
def
_dataloader
(
self
):
'''
Initializes and returns the dataloader.
'''
dataset
=
self
.
_dataset
()
if
dist
.
is_initialized
():
sampler
=
data
.
distributed
.
DistributedSampler
(
dataset
,
num_replicas
=
dist
.
get_world_size
(),
rank
=
dist
.
get_rank
()
)
else
:
sampler
=
None
dataloader
=
data
.
DataLoader
(
dataset
,
batch_size
=
self
.
batch_size
,
num_workers
=
self
.
num_workers
,
pin_memory
=
True
,
sampler
=
sampler
,
shuffle
=
sampler
is
None
and
self
.
shuffle
is
True
)
return
dataloader
class
VideoDataset
(
data
.
Dataset
):
"""
Generic dataset for videos files stored in folders.
Videos of the same class are expected to be stored in a single folder. Multiple folders can exist in the provided directory.
The class depends on `torchvision.datasets.video_utils.VideoClips` to load the videos.
Returns BCTHW videos in the range [0, 1].
Args:
data_folder: Path to the folder with corresponding videos stored.
sequence_length: Length of extracted video sequences.
resolution: Resolution of the returned videos.
sample_every_n_frames: Sample every n frames from the video.
"""
def
__init__
(
self
,
data_folder
:
str
,
sequence_length
:
int
=
16
,
resolution
:
int
=
128
,
sample_every_n_frames
:
int
=
1
):
super
().
__init__
()
self
.
sequence_length
=
sequence_length
self
.
resolution
=
resolution
self
.
sample_every_n_frames
=
sample_every_n_frames
folder
=
data_folder
files
=
sum
([
glob
.
glob
(
osp
.
join
(
folder
,
'**'
,
f
'*
{
ext
}
'
),
recursive
=
True
)
for
ext
in
VID_EXTENSIONS
],
[])
warnings
.
filterwarnings
(
'ignore'
)
cache_file
=
osp
.
join
(
folder
,
f
"metadata_
{
sequence_length
}
.pkl"
)
if
not
osp
.
exists
(
cache_file
):
clips
=
VideoClips
(
files
,
sequence_length
,
num_workers
=
4
)
try
:
pickle
.
dump
(
clips
.
metadata
,
open
(
cache_file
,
'wb'
))
except
:
print
(
f
"Failed to save metadata to
{
cache_file
}
"
)
else
:
metadata
=
pickle
.
load
(
open
(
cache_file
,
'rb'
))
clips
=
VideoClips
(
files
,
sequence_length
,
_precomputed_metadata
=
metadata
)
self
.
_clips
=
clips
# instead of uniformly sampling from all possible clips, we sample uniformly from all possible videos
self
.
_clips
.
get_clip_location
=
self
.
get_random_clip_from_video
def
get_random_clip_from_video
(
self
,
idx
:
int
)
->
tuple
:
'''
Sample a random clip starting index from the video.
Args:
idx: Index of the video.
'''
# Note that some videos may not contain enough frames, we skip those videos here.
while
self
.
_clips
.
clips
[
idx
].
shape
[
0
]
<=
0
:
idx
+=
1
n_clip
=
self
.
_clips
.
clips
[
idx
].
shape
[
0
]
clip_id
=
random
.
randint
(
0
,
n_clip
-
1
)
return
idx
,
clip_id
def
__len__
(
self
):
return
self
.
_clips
.
num_videos
()
def
__getitem__
(
self
,
idx
):
resolution
=
self
.
resolution
while
True
:
try
:
video
,
_
,
_
,
idx
=
self
.
_clips
.
get_clip
(
idx
)
except
Exception
as
e
:
print
(
idx
,
e
)
idx
=
(
idx
+
1
)
%
self
.
_clips
.
num_clips
()
continue
break
return
dict
(
**
preprocess
(
video
,
resolution
,
sample_every_n_frames
=
self
.
sample_every_n_frames
))
class
FrameDataset
(
data
.
Dataset
):
"""
Generic dataset for videos stored as images. The loading will iterates over all the folders and subfolders
in the provided directory. Each leaf folder is assumed to contain frames from a single video.
Args:
data_folder: path to the folder with video frames. The folder
should contain folders with frames from each video.
sequence_length: length of extracted video sequences
resolution: resolution of the returned videos
sample_every_n_frames: sample every n frames from the video
"""
def
__init__
(
self
,
data_folder
,
sequence_length
,
resolution
=
64
,
sample_every_n_frames
=
1
):
self
.
resolution
=
resolution
self
.
sequence_length
=
sequence_length
self
.
sample_every_n_frames
=
sample_every_n_frames
self
.
data_all
=
self
.
load_video_frames
(
data_folder
)
self
.
video_num
=
len
(
self
.
data_all
)
def
__getitem__
(
self
,
index
):
batch_data
=
self
.
getTensor
(
index
)
return_list
=
{
'video'
:
batch_data
}
return
return_list
def
load_video_frames
(
self
,
dataroot
:
str
)
->
list
:
'''
Loads all the video frames under the dataroot and returns a list of all the video frames.
Args:
dataroot: The root directory containing the video frames.
Returns:
A list of all the video frames.
'''
data_all
=
[]
frame_list
=
os
.
walk
(
dataroot
)
for
_
,
meta
in
enumerate
(
frame_list
):
root
=
meta
[
0
]
try
:
frames
=
sorted
(
meta
[
2
],
key
=
lambda
item
:
int
(
item
.
split
(
'.'
)[
0
].
split
(
'_'
)[
-
1
]))
except
:
print
(
meta
[
0
],
meta
[
2
])
if
len
(
frames
)
<
max
(
0
,
self
.
sequence_length
*
self
.
sample_every_n_frames
):
continue
frames
=
[
os
.
path
.
join
(
root
,
item
)
for
item
in
frames
if
is_image_file
(
item
)
]
if
len
(
frames
)
>
max
(
0
,
self
.
sequence_length
*
self
.
sample_every_n_frames
):
data_all
.
append
(
frames
)
return
data_all
def
getTensor
(
self
,
index
:
int
)
->
torch
.
Tensor
:
'''
Returns a tensor of the video frames at the given index.
Args:
index: The index of the video frames to return.
Returns:
A BCTHW tensor in the range `[0, 1]` of the video frames at the given index.
'''
video
=
self
.
data_all
[
index
]
video_len
=
len
(
video
)
# load the entire video when sequence_length = -1, whiel the sample_every_n_frames has to be 1
if
self
.
sequence_length
==
-
1
:
assert
self
.
sample_every_n_frames
==
1
start_idx
=
0
end_idx
=
video_len
else
:
n_frames_interval
=
self
.
sequence_length
*
self
.
sample_every_n_frames
start_idx
=
random
.
randint
(
0
,
video_len
-
n_frames_interval
)
end_idx
=
start_idx
+
n_frames_interval
img
=
Image
.
open
(
video
[
0
])
h
,
w
=
img
.
height
,
img
.
width
if
h
>
w
:
half
=
(
h
-
w
)
//
2
cropsize
=
(
0
,
half
,
w
,
half
+
w
)
# left, upper, right, lower
elif
w
>
h
:
half
=
(
w
-
h
)
//
2
cropsize
=
(
half
,
0
,
half
+
h
,
h
)
images
=
[]
for
i
in
range
(
start_idx
,
end_idx
,
self
.
sample_every_n_frames
):
path
=
video
[
i
]
img
=
Image
.
open
(
path
)
if
h
!=
w
:
img
=
img
.
crop
(
cropsize
)
img
=
img
.
resize
(
(
self
.
resolution
,
self
.
resolution
),
Image
.
ANTIALIAS
)
img
=
np
.
asarray
(
img
,
dtype
=
np
.
float32
)
img
/=
255.
img_tensor
=
preprocess_image
(
img
).
unsqueeze
(
0
)
images
.
append
(
img_tensor
)
video_clip
=
torch
.
cat
(
images
).
permute
(
3
,
0
,
1
,
2
)
return
video_clip
def
__len__
(
self
):
return
self
.
video_num
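Editor's note: a short sketch of driving the get_dataloader factory defined above over a folder of video files; the path is a placeholder and the smaller batch size is an assumption for illustration.

# Build a loader over video files found recursively under a folder.
loader = get_dataloader(
    data_path="/path/to/videos",   # placeholder
    image_folder=False,            # videos stored as files, not as frame folders
    resolution=128,
    sequence_length=16,
    sample_every_n_frames=1,
    batch_size=4,
    num_workers=2,
)

for batch in loader:
    video = batch["video"]         # (B, C, T, H, W) in [0, 1]
    print(video.shape)
    break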
LatentSync/latentsync/trepa/utils/metric_utils.py
0 → 100644
# Adapted from https://github.com/universome/stylegan-v/blob/master/src/metrics/metric_utils.py
import
os
import
random
import
torch
import
pickle
import
numpy
as
np
from
typing
import
List
,
Tuple
def
seed_everything
(
seed
):
random
.
seed
(
seed
)
os
.
environ
[
'PYTHONHASHSEED'
]
=
str
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
class FeatureStats:
    '''
    Class to store statistics of features, including all features and mean/covariance.

    Args:
        capture_all: Whether to store all the features.
        capture_mean_cov: Whether to store mean and covariance.
        max_items: Maximum number of items to store.
    '''

    def __init__(self, capture_all: bool = False, capture_mean_cov: bool = False, max_items: int = None):
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0
        self.num_features = None
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features: int):
        '''
        Set the number of feature dimensions.

        Args:
            num_features: Number of feature dimensions.
        '''
        if self.num_features is not None:
            assert num_features == self.num_features
        else:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)

    def is_full(self) -> bool:
        '''
        Check if the maximum number of samples is reached.

        Returns:
            True if the storage is full, False otherwise.
        '''
        return (self.max_items is not None) and (self.num_items >= self.max_items)

    def append(self, x: np.ndarray):
        '''
        Add the newly computed features to the list. Update the mean and covariance.

        Args:
            x: New features to record.
        '''
        x = np.asarray(x, dtype=np.float32)
        assert x.ndim == 2
        if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
            if self.num_items >= self.max_items:
                return
            x = x[: self.max_items - self.num_items]

        self.set_num_features(x.shape[1])
        self.num_items += x.shape[0]
        if self.capture_all:
            self.all_features.append(x)
        if self.capture_mean_cov:
            x64 = x.astype(np.float64)
            self.raw_mean += x64.sum(axis=0)
            self.raw_cov += x64.T @ x64
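append only keeps two running sums: raw_mean accumulates the per-dimension sum of the features and raw_cov accumulates the sum of outer products, from which get_mean_cov below recovers mean = sum(x)/n and cov = sum(x xᵀ)/n - mean·meanᵀ. A quick numerical check of that identity against NumPy's own estimator (my illustration, not part of the repo):

import numpy as np

x = np.random.randn(1000, 4).astype(np.float64)
raw_mean = x.sum(axis=0)   # what FeatureStats accumulates per batch
raw_cov = x.T @ x
n = x.shape[0]

mean = raw_mean / n
cov = raw_cov / n - np.outer(mean, mean)

# np.cov with bias=True uses the same 1/n normalization.
assert np.allclose(cov, np.cov(x, rowvar=False, bias=True))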
    def append_torch(self, x: torch.Tensor, rank: int, num_gpus: int):
        '''
        Add the newly computed PyTorch features to the list. Update the mean and covariance.

        Args:
            x: New features to record.
            rank: Rank of the current GPU.
            num_gpus: Total number of GPUs.
        '''
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            ys = []
            for src in range(num_gpus):
                y = x.clone()
                torch.distributed.broadcast(y, src=src)
                ys.append(y)
            x = torch.stack(ys, dim=1).flatten(0, 1)  # interleave samples
        self.append(x.cpu().numpy())
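With more than one GPU, every rank broadcasts its batch and the per-rank batches are interleaved sample by sample via stack(dim=1).flatten(0, 1), so all ranks append identical, deterministically ordered features. A CPU-only illustration of just that interleaving step (two "ranks" simulated with plain tensors, no torch.distributed involved):

import torch

a = torch.tensor([[0.], [1.]])    # features produced on rank 0
b = torch.tensor([[10.], [11.]])  # features produced on rank 1

merged = torch.stack([a, b], dim=1).flatten(0, 1)
print(merged.squeeze(1).tolist())  # [0.0, 10.0, 1.0, 11.0] — rank-0 and rank-1 samples alternate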
    def get_all(self) -> np.ndarray:
        '''
        Get all the stored features as a NumPy array.

        Returns:
            Concatenation of the stored features.
        '''
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self) -> torch.Tensor:
        '''
        Get all the stored features as a PyTorch tensor.

        Returns:
            Concatenation of the stored features.
        '''
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self) -> Tuple[np.ndarray, np.ndarray]:
        '''
        Get the mean and covariance of the stored features.

        Returns:
            Mean and covariance of the stored features.
        '''
        assert self.capture_mean_cov
        mean = self.raw_mean / self.num_items
        cov = self.raw_cov / self.num_items
        cov = cov - np.outer(mean, mean)
        return mean, cov
    def save(self, pkl_file: str):
        '''
        Save the features and statistics to a pickle file.

        Args:
            pkl_file: Path to the pickle file.
        '''
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file: str) -> 'FeatureStats':
        '''
        Load the features and statistics from a pickle file.

        Args:
            pkl_file: Path to the pickle file.
        '''
        with open(pkl_file, 'rb') as f:
            s = pickle.load(f)
        obj = FeatureStats(capture_all=s['capture_all'], max_items=s['max_items'])
        obj.__dict__.update(s)
        print('Loaded %d features from %s' % (obj.num_items, pkl_file))
        return obj
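This module is adapted from StyleGAN-V's metric utilities, where mean/covariance statistics of this kind typically feed a Fréchet-style distance (e.g. FID/FVD) between real and generated feature distributions. A minimal end-to-end sketch under that assumption — the random features and the scipy-based distance function are illustrative stand-ins, not code from this repo:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, cov1, mu2, cov2):
    # Classic Fréchet formula: ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 (C1 C2)^(1/2))
    covmean, _ = linalg.sqrtm(cov1 @ cov2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(cov1 + cov2 - 2.0 * covmean))

# FeatureStats is the class defined above.
real_stats = FeatureStats(capture_mean_cov=True, max_items=10_000)
fake_stats = FeatureStats(capture_mean_cov=True, max_items=10_000)

# Stand-in features; in practice these would come from a video backbone.
real_stats.append(np.random.randn(2048, 64))
fake_stats.append(np.random.randn(2048, 64) + 0.1)

mu_r, cov_r = real_stats.get_mean_cov()
mu_f, cov_f = fake_stats.get_mean_cov()
print(frechet_distance(mu_r, cov_r, mu_f, cov_f))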
LatentSync/latentsync/utils/affine_transform.py
0 → 100644
View file @
5c023842
# Adapted from https://github.com/guanjz20/StyleSync/blob/main/utils.py

import numpy as np
import cv2


def transformation_from_points(points1, points0, smooth=True, p_bias=None):
    points2 = np.array(points0)
    points2 = points2.astype(np.float64)
    points1 = points1.astype(np.float64)
    c1 = np.mean(points1, axis=0)
    c2 = np.mean(points2, axis=0)
    points1 -= c1
    points2 -= c2
    s1 = np.std(points1)
    s2 = np.std(points2)
    points1 /= s1
    points2 /= s2
    U, S, Vt = np.linalg.svd(np.matmul(points1.T, points2))
    R = (np.matmul(U, Vt)).T
    sR = (s2 / s1) * R
    T = c2.reshape(2, 1) - (s2 / s1) * np.matmul(R, c1.reshape(2, 1))
    M = np.concatenate((sR, T), axis=1)
    if smooth:
        bias = points2[2] - points1[2]
        if p_bias is None:
            p_bias = bias
        else:
            bias = p_bias * 0.2 + bias * 0.8
        p_bias = bias
        M[:, 2] = M[:, 2] + bias
    return M, p_bias
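transformation_from_points is the standard Procrustes/Umeyama recipe: center and scale-normalize both landmark sets, recover the rotation from an SVD of their cross-correlation, then assemble a 2x3 similarity-transform matrix M that maps points1 onto points0 (the smooth branch additionally low-pass filters the translation across frames). A small self-check, assuming non-collinear points and smooth=False (my own example, not part of the file):

import numpy as np

# Source landmarks and a known similarity transform (30° rotation, scale 1.5, shift (2, 3)).
src = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
theta = np.deg2rad(30)
R_true = np.array([[np.cos(theta), -np.sin(theta)],
                   [np.sin(theta),  np.cos(theta)]])
dst = 1.5 * src @ R_true.T + np.array([2.0, 3.0])

M, _ = transformation_from_points(src, dst, smooth=False)

# Applying M (as cv2.warpAffine would) should reproduce the target landmarks.
recovered = src @ M[:, :2].T + M[:, 2]
assert np.allclose(recovered, dst)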
class AlignRestore(object):
    def __init__(self, align_points=3):
        if align_points == 3:
            self.upscale_factor = 1
            self.crop_ratio = (2.8, 2.8)
            self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]])
            self.face_template = self.face_template * 2.8
            # self.face_size = (int(100 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
            self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
            self.p_bias = None
    def process(self, img, lmk_align=None, smooth=True, align_points=3):
        aligned_face, affine_matrix = self.align_warp_face(img, lmk_align, smooth)
        restored_img = self.restore_img(img, aligned_face, affine_matrix)
        cv2.imwrite("restored.jpg", restored_img)
        cv2.imwrite("aligned.jpg", aligned_face)
        return aligned_face, restored_img
    def align_warp_face(self, img, lmks3, smooth=True, border_mode="constant"):
        affine_matrix, self.p_bias = transformation_from_points(lmks3, self.face_template, smooth, self.p_bias)
        if border_mode == "constant":
            border_mode = cv2.BORDER_CONSTANT
        elif border_mode == "reflect101":
            border_mode = cv2.BORDER_REFLECT101
        elif border_mode == "reflect":
            border_mode = cv2.BORDER_REFLECT
        cropped_face = cv2.warpAffine(
            img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=[127, 127, 127]
        )
        return cropped_face, affine_matrix

    def align_warp_face2(self, img, landmark, border_mode="constant"):
        affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template)[0]
        if border_mode == "constant":
            border_mode = cv2.BORDER_CONSTANT
        elif border_mode == "reflect101":
            border_mode = cv2.BORDER_REFLECT101
        elif border_mode == "reflect":
            border_mode = cv2.BORDER_REFLECT
        cropped_face = cv2.warpAffine(
            img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)
        )
        return cropped_face, affine_matrix
    def restore_img(self, input_img, face, affine_matrix):
        h, w, _ = input_img.shape
        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
        upsample_img = cv2.resize(input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
        inverse_affine = cv2.invertAffineTransform(affine_matrix)
        inverse_affine *= self.upscale_factor
        if self.upscale_factor > 1:
            extra_offset = 0.5 * self.upscale_factor
        else:
            extra_offset = 0
        inverse_affine[:, 2] += extra_offset
        inv_restored = cv2.warpAffine(face, inverse_affine, (w_up, h_up))
        mask = np.ones((self.face_size[1], self.face_size[0]), dtype=np.float32)
        inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
        inv_mask_erosion = cv2.erode(
            inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)
        )
        pasted_face = inv_mask_erosion[:, :, None] * inv_restored
        total_face_area = np.sum(inv_mask_erosion)
        w_edge = int(total_face_area**0.5) // 20
        erosion_radius = w_edge * 2
        inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
        blur_size = w_edge * 2
        inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
        inv_soft_mask = inv_soft_mask[:, :, None]
        upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
        if np.max(upsample_img) > 256:
            upsample_img = upsample_img.astype(np.uint16)
        else:
            upsample_img = upsample_img.astype(np.uint8)
        return upsample_img
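AlignRestore thus warps the face into a fixed template crop, lets the caller edit that crop, and pastes the result back into the original frame with an eroded, Gaussian-feathered mask so the seam stays invisible. A hedged usage sketch (the image path, the three landmark coordinates, and the "edit" step are placeholders, not part of this file):

import cv2
import numpy as np

aligner = AlignRestore(align_points=3)

frame = cv2.imread("frame.jpg")  # hypothetical input frame
# Three alignment landmarks (e.g. eye centers and nose tip); values are illustrative only.
lmk3 = np.array([[180.0, 210.0], [260.0, 208.0], [220.0, 270.0]])

aligned_face, affine_matrix = aligner.align_warp_face(frame, lmk3, smooth=True)
edited_face = aligned_face  # here a lipsync model would repaint the mouth region
restored = aligner.restore_img(frame, edited_face, affine_matrix)
cv2.imwrite("restored_frame.jpg", restored)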
class laplacianSmooth:
    def __init__(self, smoothAlpha=0.3):
        self.smoothAlpha = smoothAlpha
        self.pts_last = None

    def smooth(self, pts_cur):
        if self.pts_last is None:
            self.pts_last = pts_cur.copy()
            return pts_cur.copy()
        x1 = min(pts_cur[:, 0])
        x2 = max(pts_cur[:, 0])
        y1 = min(pts_cur[:, 1])
        y2 = max(pts_cur[:, 1])
        width = x2 - x1
        pts_update = []
        for i in range(len(pts_cur)):
            x_new, y_new = pts_cur[i]
            x_old, y_old = self.pts_last[i]
            tmp = (x_new - x_old) ** 2 + (y_new - y_old) ** 2
            w = np.exp(-tmp / (width * self.smoothAlpha))
            x = x_old * w + x_new * (1 - w)
            y = y_old * w + y_new * (1 - w)
            pts_update.append([x, y])
        pts_update = np.array(pts_update)
        self.pts_last = pts_update.copy()
        return pts_update
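laplacianSmooth is a stateful temporal filter for landmark tracks: small inter-frame displacements (relative to the face width) are heavily damped toward the previous position, while large, genuine motions pass through almost unchanged. A hedged per-frame usage sketch — get_landmarks and the frame list stand in for whatever detector and decoder the pipeline actually uses:

import numpy as np

smoother = laplacianSmooth(smoothAlpha=0.3)

def get_landmarks(frame):
    # Placeholder for a real landmark detector; returns an (N, 2) float array with some jitter.
    return np.array([[180.0, 210.0], [260.0, 208.0], [220.0, 270.0]]) + np.random.randn(3, 2)

video_frames = [None] * 10  # stand-in for decoded frames
for frame in video_frames:
    raw_pts = get_landmarks(frame)
    stable_pts = smoother.smooth(raw_pts)  # jittered detections become a temporally stable track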