ModelZoo / STAR, Commits

Commit c9ce7f39, authored Dec 10, 2025 by yangzhong
    rename unet_v2v_init.py
Parent: 1f5da520

1 changed file, +2319 / -0:
    video_to_video/modules/unet_v2v_init.py (new file, mode 100644)

# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from abc import abstractmethod
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers
import xformers.ops
from einops import rearrange, repeat
from fairscale.nn.checkpoint import checkpoint_wrapper
from timm.models.vision_transformer import Mlp

USE_TEMPORAL_TRANSFORMER = True

class CaptionEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """

    def __init__(self,
                 in_channels,
                 hidden_size,
                 uncond_prob,
                 act_layer=nn.GELU(approximate="tanh"),
                 token_num=120):
        super().__init__()
        self.y_proj = Mlp(
            in_features=in_channels,
            hidden_features=hidden_size,
            out_features=hidden_size,
            act_layer=act_layer,
            drop=0)
        self.register_buffer(
            "y_embedding",
            nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return caption

    def forward(self, caption, train, force_drop_ids=None):
        if train:
            assert caption.shape[2:] == self.y_embedding.shape
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids)
        caption = self.y_proj(caption)
        return caption

class DropPath(nn.Module):
    r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
    """

    def __init__(self, p):
        super(DropPath, self).__init__()
        self.p = p

    def forward(self, *args, zero=None, keep=None):
        if not self.training:
            return args[0] if len(args) == 1 else args

        # params
        x = args[0]
        b = x.size(0)
        n = (torch.rand(b) < self.p).sum()

        # non-zero and non-keep mask
        mask = x.new_ones(b, dtype=torch.bool)
        if keep is not None:
            mask[keep] = False
        if zero is not None:
            mask[zero] = False

        # drop-path index
        index = torch.where(mask)[0]
        index = index[torch.randperm(len(index))[:n]]
        if zero is not None:
            index = torch.cat([index, torch.where(zero)[0]], dim=0)

        # drop-path multiplier
        multiplier = x.new_ones(b)
        multiplier[index] = 0.0
        output = tuple(u * self.broadcast(multiplier, u) for u in args)
        return output[0] if len(args) == 1 else output

    def broadcast(self, src, dst):
        assert src.size(0) == dst.size(0)
        shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
        return src.view(shape)

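
def _demo_drop_path():
    # Illustrative usage sketch, not part of the original commit: DropPath zeroes
    # whole samples with probability p during training, without rescaling the
    # survivors. The tensor sizes below are made up for the example.
    drop = DropPath(p=0.5)
    drop.train()
    feats = torch.randn(8, 16)
    out = drop(feats)
    assert out.shape == feats.shape  # dropped rows are exactly zero, kept rows are unchanged
    return out
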
def sinusoidal_embedding(timesteps, dim):
    # check input
    half = dim // 2
    timesteps = timesteps.float()

    # compute sinusoidal embedding
    sinusoid = torch.outer(
        timesteps, torch.pow(10000, -torch.arange(half).to(timesteps).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    if dim % 2 != 0:
        x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
    return x

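
def _demo_sinusoidal_embedding():
    # Illustrative usage sketch, not part of the original commit: embed 8 diffusion
    # timesteps into the model width that `time_embed` consumes (dim=320 by default).
    timesteps = torch.arange(8)
    emb = sinusoidal_embedding(timesteps, 320)
    assert emb.shape == (8, 320)  # [cos | sin] halves concatenated along dim 1
    return emb
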
def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
        # avoid masking all, which will cause a find_unused_parameters error
        if mask.all():
            mask[0] = False
        return mask

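
def _demo_prob_mask_like():
    # Illustrative usage sketch, not part of the original commit: draw a per-sample
    # boolean mask for a batch of 4, each entry True with probability ~0.5.
    mask = prob_mask_like((4, ), prob=0.5, device=torch.device('cpu'))
    assert mask.shape == (4, ) and mask.dtype == torch.bool
    return mask
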
class MemoryEfficientCrossAttention(nn.Module):

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 max_bs=16384,
                 dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.max_bs = max_bs
        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3).reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # actually compute the attention, what we cannot get enough of.
        if q.shape[0] > self.max_bs:
            q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0)
            k_list = torch.chunk(k, k.shape[0] // self.max_bs, dim=0)
            v_list = torch.chunk(v, v.shape[0] // self.max_bs, dim=0)
            out_list = []
            for q_1, k_1, v_1 in zip(q_list, k_list, v_list):
                out = xformers.ops.memory_efficient_attention(
                    q_1, k_1, v_1, attn_bias=None, op=self.attention_op)
                out_list.append(out)
            out = torch.cat(out_list, dim=0)
        else:
            out = xformers.ops.memory_efficient_attention(
                q, k, v, attn_bias=None, op=self.attention_op)

        if exists(mask):
            raise NotImplementedError
        out = (
            out.unsqueeze(0).reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head))
        return self.to_out(out)

class RelativePositionBias(nn.Module):

    def __init__(self, heads=8, num_buckets=32, max_distance=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position,
                                  num_buckets=32,
                                  max_distance=128):
        ret = 0
        n = -relative_position

        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact)
            / math.log(max_distance / max_exact) *  # noqa
            (num_buckets - max_exact)).long()
        val_if_large = torch.min(
            val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, n, device):
        q_pos = torch.arange(n, dtype=torch.long, device=device)
        k_pos = torch.arange(n, dtype=torch.long, device=device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(
            rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> h i j')

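
def _demo_relative_position_bias():
    # Illustrative usage sketch, not part of the original commit: pairwise temporal
    # bias for 16 frames and 8 heads, in the (heads, frames, frames) layout that is
    # added to the attention logits of TemporalAttentionBlock via `pos_bias`.
    bias_module = RelativePositionBias(heads=8, num_buckets=32, max_distance=128)
    bias = bias_module(16, device=torch.device('cpu'))
    assert bias.shape == (8, 16, 16)
    return bias
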
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True,
                 is_ctrl=False):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                disable_self_attn=disable_self_attn,
                checkpoint=use_checkpoint,
                local_type='space',
                is_ctrl=is_ctrl) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
        else:
            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        _, _, h, w = x.shape
        # print('x shape:', x.shape) # [64, 320, 90, 160]
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i], h=h, w=w)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in

_ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')


class CrossAttention(nn.Module):

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(
            lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        # force cast to fp32 to avoid overflowing
        if _ATTN_PRECISION == 'fp32':
            with torch.autocast(enabled=False, device_type='cuda'):
                q, k = q.float(), k.float()
                sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        else:
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        del q, k

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = torch.einsum('b i j, b j d -> b i d', sim, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)

class SpatialAttention(nn.Module):

    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=7,
            padding=7 // 2,
            bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        avg_out = torch.mean(x, dim=1, keepdim=True)
        weight = torch.cat([max_out, avg_out], dim=1)
        weight = self.conv1(weight)
        out = self.sigmoid(weight) * x
        return out

class TemporalLocalAttention(nn.Module):

    def __init__(self):
        super(TemporalLocalAttention, self).__init__()
        self.conv1 = nn.Linear(in_features=2, out_features=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_out, _ = torch.max(x, dim=-1, keepdim=True)
        avg_out = torch.mean(x, dim=-1, keepdim=True)
        weight = torch.cat([max_out, avg_out], dim=-1)
        weight = self.conv1(weight)
        out = self.sigmoid(weight) * x
        return out

class BasicTransformerBlock(nn.Module):

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False,
                 local_type=None,
                 is_ctrl=False):
        super().__init__()
        self.local_type = local_type
        self.is_ctrl = is_ctrl
        attn_cls = MemoryEfficientCrossAttention
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(  # self-attn
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        attn_cls2 = MemoryEfficientCrossAttention
        self.attn2 = attn_cls2(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

        if self.local_type == 'space' and self.is_ctrl:
            self.local1 = SpatialAttention()
        if self.local_type == 'temp' and self.is_ctrl:
            self.local1 = TemporalLocalAttention()
            self.local2 = TemporalLocalAttention()

    def forward_(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(),
                          self.checkpoint)

    def forward(self, x, context=None, h=None, w=None):
        if self.local_type == 'space' and self.is_ctrl:
            # [b*t,(hw), c]
            x_local = rearrange(x, 'b (h w) c -> b c h w', h=h)
            x_local = self.local1(x_local)
            x_local = rearrange(x_local, 'b c h w -> b (h w) c')
            x = self.attn1(
                self.norm1(x_local),
                context=context if self.disable_self_attn else None) + x
            x = self.attn2(self.norm2(x), context=context) + x
            # cross attention or self-attention
            x = self.ff(self.norm3(x)) + x
        if self.local_type == 'temp' and self.is_ctrl:
            x_local = self.local1(x)
            x = self.attn1(
                self.norm1(x_local),
                context=context if self.disable_self_attn else None) + x
            x_local = self.local2(x)
            x = self.attn2(self.norm2(x_local), context=context) + x
            x = self.ff(self.norm3(x)) + x
        return x

# feedforward
class GEGLU(nn.Module):

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

class FeedForward(nn.Module):

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(nn.Linear(dim, inner_dim),
                                   nn.GELU()) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(project_in, nn.Dropout(dropout),
                                 nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.net(x)

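
def _demo_feedforward():
    # Illustrative usage sketch, not part of the original commit: the gated (GEGLU)
    # feed-forward used inside BasicTransformerBlock maps (batch, tokens, dim) to
    # the same shape. The sizes below are made up for the example.
    ff = FeedForward(dim=64, glu=True, dropout=0.0)
    tokens = torch.randn(2, 10, 64)
    out = ff(tokens)
    assert out.shape == tokens.shape
    return out
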
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest')
        else:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = x[..., 1:-1, :]
        if self.use_conv:
            x = self.conv(x)
        return x

class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        use_temporal_conv=True,
        use_image_dataset=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_temporal_conv = use_temporal_conv

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)

        if self.use_temporal_conv:
            self.temopral_conv = TemporalConvBlock_v2(
                self.out_channels,
                self.out_channels,
                dropout=0.1,
                use_image_dataset=use_image_dataset)

    def forward(self, x, emb, batch_size, variant_info=None):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return self._forward(x, emb, batch_size, variant_info)

    def _forward(self, x, emb, batch_size, variant_info):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        h = self.skip_connection(x) + h

        if self.use_temporal_conv:
            h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
            h = self.temopral_conv(h, variant_info=variant_info)
            h = rearrange(h, 'b c f h w -> (b f) c h w')
        return h

class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=(2, 1)):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = nn.Conv2d(
                self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)

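
def _demo_downsample():
    # Illustrative usage sketch, not part of the original commit: the strided-conv
    # Downsample halves the spatial resolution between UNet stages. `padding=1` is
    # passed explicitly here; the class default is the asymmetric (2, 1).
    feat = torch.randn(1, 32, 16, 16)  # (batch, channels, height, width)
    down = Downsample(32, use_conv=True, dims=2, padding=1)
    out = down(feat)
    assert out.shape == (1, 32, 8, 8)
    return out
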
class Resample(nn.Module):

    def __init__(self, in_dim, out_dim, mode):
        assert mode in ['none', 'upsample', 'downsample']
        super(Resample, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.mode = mode

    def forward(self, x, reference=None):
        if self.mode == 'upsample':
            assert reference is not None
            x = F.interpolate(x, size=reference.shape[-2:], mode='nearest')
        elif self.mode == 'downsample':
            x = F.adaptive_avg_pool2d(
                x, output_size=tuple(u // 2 for u in x.shape[-2:]))
        return x

class ResidualBlock(nn.Module):

    def __init__(self,
                 in_dim,
                 embed_dim,
                 out_dim,
                 use_scale_shift_norm=True,
                 mode='none',
                 dropout=0.0):
        super(ResidualBlock, self).__init__()
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.use_scale_shift_norm = use_scale_shift_norm
        self.mode = mode

        # layers
        self.layer1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv2d(in_dim, out_dim, 3, padding=1))
        self.resample = Resample(in_dim, in_dim, mode)
        self.embedding = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, out_dim * 2 if use_scale_shift_norm else out_dim))
        self.layer2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv2d(out_dim, out_dim, 3, padding=1))
        self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
            in_dim, out_dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.layer2[-1].weight)

    def forward(self, x, e, reference=None):
        identity = self.resample(x, reference)
        x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
        e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
        if self.use_scale_shift_norm:
            scale, shift = e.chunk(2, dim=1)
            x = self.layer2[0](x) * (1 + scale) + shift
            x = self.layer2[1:](x)
        else:
            x = x + e
            x = self.layer2(x)
        x = x + self.shortcut(identity)
        return x

class AttentionBlock(nn.Module):

    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
        # consider head_dim first, then num_heads
        num_heads = dim // head_dim if head_dim else num_heads
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(AttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x, context=None):
        r"""x: [B, C, H, W].
        context: [B, L, C] or None.
        """
        identity = x
        b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        x = self.norm(x)
        q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
        if context is not None:
            ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(
                0, 2, 3, 1).chunk(2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)

        # compute attention
        attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
        attn = F.softmax(attn, dim=-1)

        # gather context
        x = torch.matmul(v, attn.transpose(-1, -2))
        x = x.reshape(b, c, h, w)

        # output
        x = self.proj(x)
        return x + identity

class TemporalAttentionBlock(nn.Module):

    def __init__(self,
                 dim,
                 heads=4,
                 dim_head=32,
                 rotary_emb=None,
                 use_image_dataset=False,
                 use_sim_mask=False):
        super().__init__()
        # consider num_heads first, as pos_bias needs fixed num_heads
        dim_head = dim // heads
        assert heads * dim_head == dim
        self.use_image_dataset = use_image_dataset
        self.use_sim_mask = use_sim_mask

        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = nn.GroupNorm(32, dim)
        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, hidden_dim * 3)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, pos_bias=None, focus_present_mask=None, video_mask=None):
        identity = x
        n, height, device = x.shape[2], x.shape[-2], x.device

        x = self.norm(x)
        x = rearrange(x, 'b c f h w -> b (h w) f c')

        qkv = self.to_qkv(x).chunk(3, dim=-1)

        if exists(focus_present_mask) and focus_present_mask.all():
            # if all batch samples are focusing on present
            # it would be equivalent to passing that token's values (v=qkv[-1]) through to the output
            values = qkv[-1]
            out = self.to_out(values)
            out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)
            return out + identity

        # split out heads
        q = rearrange(qkv[0], '... n (h d) -> ... h n d', h=self.heads)
        k = rearrange(qkv[1], '... n (h d) -> ... h n d', h=self.heads)
        v = rearrange(qkv[2], '... n (h d) -> ... h n d', h=self.heads)

        # scale
        q = q * self.scale

        # rotate positions into queries and keys for time attention
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        # similarity
        # shape [b (hw) h n n], n=f
        sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)

        # relative positional bias
        if exists(pos_bias):
            sim = sim + pos_bias

        if (focus_present_mask is None and video_mask is not None):
            # video_mask: [B, n]
            mask = video_mask[:, None, :] * video_mask[:, :, None]
            mask = mask.unsqueeze(1).unsqueeze(1)
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
        elif exists(focus_present_mask) and not (~focus_present_mask).all():
            attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool)
            attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)

            mask = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
            )
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        if self.use_sim_mask:
            sim_mask = torch.tril(
                torch.ones((n, n), device=device, dtype=torch.bool), diagonal=0)
            sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max)

        # numerical stability
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values
        out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
        out = rearrange(out, '... h n d -> ... n (h d)')
        out = self.to_out(out)
        out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)

        if self.use_image_dataset:
            out = identity + 0 * out
        else:
            out = identity + out
        return out

class TemporalTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True,
                 only_self_att=True,
                 multiply_zero=False,
                 is_ctrl=False):
        super().__init__()
        self.multiply_zero = multiply_zero
        self.only_self_att = only_self_att
        self.use_adaptor = False
        if self.only_self_att:
            context_dim = None
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv1d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)
            if self.use_adaptor:
                self.adaptor_in = nn.Linear(frames, frames)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                checkpoint=use_checkpoint,
                local_type='temp',
                is_ctrl=is_ctrl) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
            if self.use_adaptor:
                self.adaptor_out = nn.Linear(frames, frames)
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if self.only_self_att:
            context = None
        if not isinstance(context, list):
            context = [context]
        b, _, _, h, w = x.shape
        x_in = x
        x = self.norm(x)

        if not self.use_linear:
            x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
            x = self.proj_in(x)
        if self.use_linear:
            x = rearrange(x, 'b c f h w -> (b h w) f c').contiguous()
            x = self.proj_in(x)
            x = rearrange(x, 'bhw f c -> bhw c f').contiguous()
        # print('x shape:', x.shape) # [28800, 512, 32]

        if self.only_self_att:
            # no cross-attention
            x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
            for i, block in enumerate(self.transformer_blocks):
                x = block(x, h=h, w=w)
            # print('x shape:', x.shape) # [43200, 32, 512]
            x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
        else:
            x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
            for i, block in enumerate(self.transformer_blocks):
                context[i] = rearrange(
                    context[i], '(b f) l con -> b f l con',
                    f=self.frames).contiguous()
                # calculate each batch one by one (since the number in shape cannot be greater than 65,535 for some packages)
                for j in range(b):
                    context_i_j = repeat(
                        context[i][j],
                        'f l con -> (f r) l con',
                        r=(h * w) // self.frames,
                        f=self.frames).contiguous()
                    x[j] = block(x[j], context=context_i_j)

        if self.use_linear:
            x = rearrange(x, 'b hw f c -> (b hw) f c').contiguous()
            x = self.proj_out(x)
            x = rearrange(x, '(b h w) f c -> b c f h w', b=b, h=h, w=w).contiguous()
        if not self.use_linear:
            # print('x shape:', x.shape) # [2, 21600, 32, 512]
            x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
            x = self.proj_out(x)
            x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()

        if self.multiply_zero:
            x = 0.0 * x + x_in
        else:
            x = x + x_in
        return x

class TemporalAttentionMultiBlock(nn.Module):

    def __init__(
        self,
        dim,
        heads=4,
        dim_head=32,
        rotary_emb=None,
        use_image_dataset=False,
        use_sim_mask=False,
        temporal_attn_times=1,
    ):
        super().__init__()
        self.att_layers = nn.ModuleList([
            TemporalAttentionBlock(dim, heads, dim_head, rotary_emb,
                                   use_image_dataset, use_sim_mask)
            for _ in range(temporal_attn_times)
        ])

    def forward(self, x, pos_bias=None, focus_present_mask=None, video_mask=None):
        for layer in self.att_layers:
            x = layer(x, pos_bias, focus_present_mask, video_mask)
        return x

class InitTemporalConvBlock(nn.Module):

    def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False):
        super(InitTemporalConvBlock, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # conv layers
        self.conv = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv[-1].weight)
        nn.init.zeros_(self.conv[-1].bias)

    def forward(self, x):
        identity = x
        x = self.conv(x)
        if self.use_image_dataset:
            x = identity + 0 * x
        else:
            x = identity + x
        return x

class TemporalConvBlock(nn.Module):

    def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False):
        super(TemporalConvBlock, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv2[-1].weight)
        nn.init.zeros_(self.conv2[-1].bias)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        if self.use_image_dataset:
            x = identity + 0 * x
        else:
            x = identity + x
        return x

class TemporalConvBlock_v2(nn.Module):

    def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False):
        super(TemporalConvBlock_v2, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, x, variant_info=None):
        if variant_info is not None and variant_info.get('type') == 'variant2':
            # print(x.shape) # torch.Size([1, 320, 32, 90, 160])
            _, _, f, _, _ = x.shape
            assert f % 4 == 0, "f must be divisible by 4"
            x_short = rearrange(x, "b c (n s) h w -> (n b) c s h w", n=4)
            x_short = self.conv1(x_short)
            x_short = self.conv2(x_short)
            x_short = self.conv3(x_short)
            x_short = self.conv4(x_short)
            x_short = rearrange(x_short, "(n b) c s h w -> b c (n s) h w", n=4)

            identity = x
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.conv4(x)

            x = x * (1 - variant_info['alpha']) + x_short * variant_info['alpha']
        elif variant_info is not None and variant_info.get('type') == 'variant1':
            identity = x
            x_long, x_short = x.chunk(2, dim=0)
            x_short = rearrange(x_short, "b c (n s) h w -> (n b) c s h w", n=4)
            x_short = self.conv1(x_short)
            x_short = self.conv2(x_short)
            x_short = self.conv3(x_short)
            x_short = self.conv4(x_short)
            x_short = rearrange(x_short, "(n b) c s h w -> b c (n s) h w", n=4)

            x_long = self.conv1(x_long)
            x_long = self.conv2(x_long)
            x_long = self.conv3(x_long)
            x_long = self.conv4(x_long)

            x = torch.cat([x_long, x_short], dim=0)
        elif variant_info is None:
            identity = x
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.conv4(x)

        if self.use_image_dataset:
            x = identity + 0.0 * x
        else:
            x = identity + x
        return x

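
def _demo_temporal_conv_block_v2():
    # Illustrative usage sketch, not part of the original commit: with
    # variant_info=None the block runs its four (3, 1, 1) temporal convolutions on a
    # (B, C, F, H, W) feature map; because the last conv is zero-initialized, the
    # block behaves as an exact identity right after construction.
    block = TemporalConvBlock_v2(in_dim=32, out_dim=32, dropout=0.0)
    video_feat = torch.randn(1, 32, 8, 16, 16)
    out = block(video_feat, variant_info=None)
    assert torch.allclose(out, video_feat)
    return out
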
class Vid2VidSDUNet(nn.Module):

    def __init__(self,
                 in_dim=4,
                 dim=320,
                 y_dim=1024,
                 context_dim=1024,
                 out_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_heads=8,
                 head_dim=64,
                 num_res_blocks=2,
                 attn_scales=[1 / 1, 1 / 2, 1 / 4],
                 use_scale_shift_norm=True,
                 dropout=0.1,
                 temporal_attn_times=1,
                 temporal_attention=True,
                 use_checkpoint=True,
                 use_image_dataset=False,
                 use_fps_condition=False,
                 use_sim_mask=False,
                 training=False,
                 inpainting=True):
        embed_dim = dim * 4
        num_heads = num_heads if num_heads else dim // 32
        super(Vid2VidSDUNet, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        self.y_dim = y_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.dim_mult = dim_mult
        # for temporal attention
        self.num_heads = num_heads
        # for spatial attention
        self.head_dim = head_dim
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.use_scale_shift_norm = use_scale_shift_norm
        self.temporal_attn_times = temporal_attn_times
        self.temporal_attention = temporal_attention
        self.use_checkpoint = use_checkpoint
        self.use_image_dataset = use_image_dataset
        self.use_fps_condition = use_fps_condition
        self.use_sim_mask = use_sim_mask
        self.training = training
        self.inpainting = inpainting

        use_linear_in_temporal = False
        transformer_depth = 1
        disabled_sa = False
        # params
        enc_dims = [dim * u for u in [1] + dim_mult]
        dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        shortcut_dims = []
        scale = 1.0

        # embeddings
        self.time_embed = nn.Sequential(
            nn.Linear(dim, embed_dim), nn.SiLU(),
            nn.Linear(embed_dim, embed_dim))
        if self.use_fps_condition:
            self.fps_embedding = nn.Sequential(
                nn.Linear(dim, embed_dim), nn.SiLU(),
                nn.Linear(embed_dim, embed_dim))
            nn.init.zeros_(self.fps_embedding[-1].weight)
            nn.init.zeros_(self.fps_embedding[-1].bias)

        # encoder
        self.input_blocks = nn.ModuleList()
        init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
        # need an initial temporal attention?
        if temporal_attention:
            if USE_TEMPORAL_TRANSFORMER:
                init_block.append(
                    TemporalTransformer(
                        dim,
                        num_heads,
                        head_dim,
                        depth=transformer_depth,
                        context_dim=context_dim,
                        disable_self_attn=disabled_sa,
                        use_linear=use_linear_in_temporal,
                        multiply_zero=use_image_dataset,
                        is_ctrl=True))
            else:
                init_block.append(
                    TemporalAttentionMultiBlock(
                        dim,
                        num_heads,
                        head_dim,
                        rotary_emb=self.rotary_emb,
                        temporal_attn_times=temporal_attn_times,
                        use_image_dataset=use_image_dataset))
        self.input_blocks.append(init_block)
        shortcut_dims.append(dim)
        for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
            for j in range(num_res_blocks):
                block = nn.ModuleList([
                    ResBlock(
                        in_dim,
                        embed_dim,
                        dropout,
                        out_channels=out_dim,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    block.append(
                        SpatialTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=1,
                            context_dim=self.context_dim,
                            disable_self_attn=False,
                            use_linear=True,
                            is_ctrl=True))
                    if self.temporal_attention:
                        if USE_TEMPORAL_TRANSFORMER:
                            block.append(
                                TemporalTransformer(
                                    out_dim,
                                    out_dim // head_dim,
                                    head_dim,
                                    depth=transformer_depth,
                                    context_dim=context_dim,
                                    disable_self_attn=disabled_sa,
                                    use_linear=use_linear_in_temporal,
                                    multiply_zero=use_image_dataset,
                                    is_ctrl=True))
                        else:
                            block.append(
                                TemporalAttentionMultiBlock(
                                    out_dim,
                                    num_heads,
                                    head_dim,
                                    rotary_emb=self.rotary_emb,
                                    use_image_dataset=use_image_dataset,
                                    use_sim_mask=use_sim_mask,
                                    temporal_attn_times=temporal_attn_times))
                in_dim = out_dim
                self.input_blocks.append(block)
                shortcut_dims.append(out_dim)
                # downsample
                if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
                    downsample = Downsample(
                        out_dim, True, dims=2, out_channels=out_dim)
                    shortcut_dims.append(out_dim)
                    scale /= 2.0
                    self.input_blocks.append(downsample)

        self.middle_block = nn.ModuleList([
            ResBlock(
                out_dim,
                embed_dim,
                dropout,
                use_scale_shift_norm=False,
                use_image_dataset=use_image_dataset,
            ),
            SpatialTransformer(
                out_dim,
                out_dim // head_dim,
                head_dim,
                depth=1,
                context_dim=self.context_dim,
                disable_self_attn=False,
                use_linear=True,
                is_ctrl=True)
        ])

        if self.temporal_attention:
            if USE_TEMPORAL_TRANSFORMER:
                self.middle_block.append(
                    TemporalTransformer(
                        out_dim,
                        out_dim // head_dim,
                        head_dim,
                        depth=transformer_depth,
                        context_dim=context_dim,
                        disable_self_attn=disabled_sa,
                        use_linear=use_linear_in_temporal,
                        multiply_zero=use_image_dataset,
                        is_ctrl=True))
            else:
                self.middle_block.append(
                    TemporalAttentionMultiBlock(
                        out_dim,
                        num_heads,
                        head_dim,
                        rotary_emb=self.rotary_emb,
                        use_image_dataset=use_image_dataset,
                        use_sim_mask=use_sim_mask,
                        temporal_attn_times=temporal_attn_times))

        self.middle_block.append(
            ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))

        # decoder
        self.output_blocks = nn.ModuleList()
        for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
            for j in range(num_res_blocks + 1):
                block = nn.ModuleList([
                    ResBlock(
                        in_dim + shortcut_dims.pop(),
                        embed_dim,
                        dropout,
                        out_dim,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    block.append(
                        SpatialTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=1,
                            context_dim=1024,
                            disable_self_attn=False,
                            use_linear=True,
                            is_ctrl=True))
                    if self.temporal_attention:
                        if USE_TEMPORAL_TRANSFORMER:
                            block.append(
                                TemporalTransformer(
                                    out_dim,
                                    out_dim // head_dim,
                                    head_dim,
                                    depth=transformer_depth,
                                    context_dim=context_dim,
                                    disable_self_attn=disabled_sa,
                                    use_linear=use_linear_in_temporal,
                                    multiply_zero=use_image_dataset,
                                    is_ctrl=True))
                        else:
                            block.append(
                                TemporalAttentionMultiBlock(
                                    out_dim,
                                    num_heads,
                                    head_dim,
                                    rotary_emb=self.rotary_emb,
                                    use_image_dataset=use_image_dataset,
                                    use_sim_mask=use_sim_mask,
                                    temporal_attn_times=temporal_attn_times))
                in_dim = out_dim

                # upsample
                if i != len(dim_mult) - 1 and j == num_res_blocks:
                    upsample = Upsample(out_dim, True, dims=2.0, out_channels=out_dim)
                    scale *= 2.0
                    block.append(upsample)
                self.output_blocks.append(block)

        # head
        self.out = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(),
            nn.Conv2d(out_dim, self.out_dim, 3, padding=1))

        # zero out the last layer params
        nn.init.zeros_(self.out[-1].weight)

    def forward(self,
                x,
                t,
                y,
                x_lr=None,
                fps=None,
                video_mask=None,
                focus_present_mask=None,
                prob_focus_present=0.,
                mask_last_frame_num=0):
        batch, c, f, h, w = x.shape
        device = x.device
        self.batch = batch

        # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
        if mask_last_frame_num > 0:
            focus_present_mask = None
            video_mask[-mask_last_frame_num:] = False
        else:
            focus_present_mask = default(
                focus_present_mask, lambda: prob_mask_like(
                    (batch, ), prob_focus_present, device=device))

        if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
            time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device)
        else:
            time_rel_pos_bias = None

        # embeddings
        e = self.time_embed(sinusoidal_embedding(t, self.dim))
        context = y

        # repeat f times for spatial e and context
        e = e.repeat_interleave(repeats=f, dim=0)
        context = context.repeat_interleave(repeats=f, dim=0)

        # always in shape (b f) c h w, except for temporal layer
        x = rearrange(x, 'b c f h w -> (b f) c h w')
        # encoder
        xs = []
        for ind, block in enumerate(self.input_blocks):
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask)
            xs.append(x)

        # middle
        for block in self.middle_block:
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask)

        # decoder
        for block in self.output_blocks:
            x = torch.cat([x, xs.pop()], dim=1)
            x = self._forward_single(
                block,
                x,
                e,
                context,
                time_rel_pos_bias,
                focus_present_mask,
                video_mask,
                reference=xs[-1] if len(xs) > 0 else None)

        # head
        x = self.out(x)

        # reshape back to (b c f h w)
        x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
        return x

    def _forward_single(self,
                        module,
                        x,
                        e,
                        context,
                        time_rel_pos_bias,
                        focus_present_mask,
                        video_mask,
                        reference=None):
        if isinstance(module, ResidualBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = x.contiguous()
            x = module(x, e, reference)
        elif isinstance(module, ResBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = x.contiguous()
            x = module(x, e, self.batch)
        elif isinstance(module, SpatialTransformer):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, TemporalTransformer):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, context)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, CrossAttention):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, MemoryEfficientCrossAttention):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, BasicTransformerBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, FeedForward):
            x = module(x, context)
        elif isinstance(module, Upsample):
            x = module(x)
        elif isinstance(module, Downsample):
            x = module(x)
        elif isinstance(module, Resample):
            x = module(x, reference)
        elif isinstance(module, TemporalAttentionBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, TemporalAttentionMultiBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, InitTemporalConvBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, TemporalConvBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, nn.ModuleList):
            for block in module:
                x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                         focus_present_mask, video_mask, reference)
        else:
            x = module(x)
        return x

class ControlledV2VUNet(Vid2VidSDUNet):

    def __init__(self):
        super(ControlledV2VUNet, self).__init__()
        self.VideoControlNet = VideoControlNet()

    def forward(
        self,
        x,
        t,
        y,
        hint=None,
        variant_info=None,
        hint_chunk=None,
        t_hint=None,
        s_cond=None,
        mask_cond=None,
        x_lr=None,
        fps=None,
        mask=None,
        video_mask=None,
        focus_present_mask=None,
        prob_focus_present=0.,
        mask_last_frame_num=0,
    ):
        batch, _, f, _, _ = x.shape
        device = x.device
        self.batch = batch

        # Process text (new added for t5 encoder)
        # y = self.VideoControlNet.y_embedder(y, self.training).squeeze(1) # [1, 1, 120, 4096] -> [B, 1, 120, 1024].squeeze(1) -> [B, 120, 1024]

        if hint_chunk is not None:
            hint = hint_chunk

        control = self.VideoControlNet(
            x, t, y, hint=hint, t_hint=t_hint,
            mask_cond=mask_cond, s_cond=s_cond,
            variant_info=variant_info)

        # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
        if mask_last_frame_num > 0:
            focus_present_mask = None
            video_mask[-mask_last_frame_num:] = False
        else:
            focus_present_mask = default(
                focus_present_mask, lambda: prob_mask_like(
                    (batch, ), prob_focus_present, device=device))

        if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
            time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device)
        else:
            time_rel_pos_bias = None

        e = self.time_embed(sinusoidal_embedding(t, self.dim))
        e = e.repeat_interleave(repeats=f, dim=0)

        # context = y
        context = y.repeat_interleave(repeats=f, dim=0)

        # always in shape (b f) c h w, except for temporal layer
        x = rearrange(x, 'b c f h w -> (b f) c h w')
        # encoder
        xs = []
        for block in self.input_blocks:
            x = self._forward_single(
                block,
                x,
                e,
                context,
                time_rel_pos_bias,
                focus_present_mask,
                video_mask,
                variant_info=variant_info)
            xs.append(x)

        # middle
        for block in self.middle_block:
            x = self._forward_single(
                block,
                x,
                e,
                context,
                time_rel_pos_bias,
                focus_present_mask,
                video_mask,
                variant_info=variant_info)

        if control is not None:
            x = control.pop() + x

        # decoder
        for block in self.output_blocks:
            if control is None:
                x = torch.cat([x, xs.pop()], dim=1)
            else:
                x = torch.cat([x, xs.pop() + control.pop()], dim=1)
            x = self._forward_single(
                block,
                x,
                e,
                context,
                time_rel_pos_bias,
                focus_present_mask,
                video_mask,
                reference=xs[-1] if len(xs) > 0 else None,
                variant_info=variant_info)

        # head
        x = self.out(x)

        # reshape back to (b c f h w)
        x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
        return x

    def _forward_single(self,
                        module,
                        x,
                        e,
                        context,
                        time_rel_pos_bias,
                        focus_present_mask,
                        video_mask,
                        reference=None,
                        variant_info=None):
        variant_info = None  # For Debug
        if isinstance(module, ResidualBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = x.contiguous()
            x = module(x, e, reference)
        elif isinstance(module, ResBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = x.contiguous()
            x = module(x, e, self.batch, variant_info)
        elif isinstance(module, SpatialTransformer):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, TemporalTransformer):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, context)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, CrossAttention):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, MemoryEfficientCrossAttention):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, BasicTransformerBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, FeedForward):
            x = module(x, context)
        elif isinstance(module, Upsample):
            x = module(x)
        elif isinstance(module, Downsample):
            x = module(x)
        elif isinstance(module, Resample):
            x = module(x, reference)
        elif isinstance(module, TemporalAttentionBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, TemporalAttentionMultiBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, InitTemporalConvBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, TemporalConvBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, nn.ModuleList):
            for block in module:
                x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                         focus_present_mask, video_mask, reference,
                                         variant_info)
        else:
            x = module(x)
        return x

class VideoControlNet(nn.Module):

    def __init__(self,
                 in_dim=4,
                 dim=320,
                 y_dim=1024,
                 context_dim=1024,
                 out_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_heads=8,
                 head_dim=64,
                 num_res_blocks=2,
                 attn_scales=[1 / 1, 1 / 2, 1 / 4],
                 use_scale_shift_norm=True,
                 dropout=0.1,
                 temporal_attn_times=1,
                 temporal_attention=True,
                 use_checkpoint=True,
                 use_image_dataset=False,
                 use_fps_condition=False,
                 use_sim_mask=False,
                 training=False,
                 inpainting=True):
        embed_dim = dim * 4
        num_heads = num_heads if num_heads else dim // 32
        super(VideoControlNet, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        self.y_dim = y_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.dim_mult = dim_mult
        # for temporal attention
        self.num_heads = num_heads
        # for spatial attention
        self.head_dim = head_dim
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.use_scale_shift_norm = use_scale_shift_norm
        self.temporal_attn_times = temporal_attn_times
        self.temporal_attention = temporal_attention
        self.use_checkpoint = use_checkpoint
        self.use_image_dataset = use_image_dataset
        self.use_fps_condition = use_fps_condition
        self.use_sim_mask = use_sim_mask
        self.training = training
        self.inpainting = inpainting

        use_linear_in_temporal = False
        transformer_depth = 1
        disabled_sa = False

        # params
        enc_dims = [dim * u for u in [1] + dim_mult]
        dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        shortcut_dims = []
        scale = 1.0

        # CaptionEmbedder (new add)
        # approx_gelu = lambda: nn.GELU(approximate="tanh")
        # self.y_embedder = CaptionEmbedder(
        #     in_channels=4096,
        #     hidden_size=1024,
        #     uncond_prob=0.1,
        #     act_layer=approx_gelu,
        #     token_num=120,
        # )

        # embeddings
        self.time_embed = nn.Sequential(
            nn.Linear(dim, embed_dim), nn.SiLU(),
            nn.Linear(embed_dim, embed_dim))
        # self.hint_time_zero_linear = zero_module(nn.Linear(embed_dim, embed_dim))

        # scale prompt
        # self.scale_cond = nn.Sequential(
        #     nn.Linear(dim, embed_dim), nn.SiLU(),
        #     zero_module(nn.Linear(embed_dim, embed_dim)))

        if self.use_fps_condition:
            self.fps_embedding = nn.Sequential(
                nn.Linear(dim, embed_dim), nn.SiLU(),
                nn.Linear(embed_dim, embed_dim))
            nn.init.zeros_(self.fps_embedding[-1].weight)
            nn.init.zeros_(self.fps_embedding[-1].bias)

        # encoder
        self.input_blocks = nn.ModuleList()
        init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])

        # need an initial temporal attention?
        if temporal_attention:
            if USE_TEMPORAL_TRANSFORMER:
                init_block.append(
                    TemporalTransformer(
                        dim,
                        num_heads,
                        head_dim,
                        depth=transformer_depth,
                        context_dim=context_dim,
                        disable_self_attn=disabled_sa,
                        use_linear=use_linear_in_temporal,
                        multiply_zero=use_image_dataset,
                        is_ctrl=True,))
            else:
                init_block.append(
                    TemporalAttentionMultiBlock(
                        dim,
                        num_heads,
                        head_dim,
                        rotary_emb=self.rotary_emb,
                        temporal_attn_times=temporal_attn_times,
                        use_image_dataset=use_image_dataset))

        self.input_blocks.append(init_block)
        self.zero_convs = nn.ModuleList([self.make_zero_conv(dim)])
        shortcut_dims.append(dim)
        for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
            for j in range(num_res_blocks):
                block = nn.ModuleList([
                    ResBlock(
                        in_dim,
                        embed_dim,
                        dropout,
                        out_channels=out_dim,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    block.append(
                        SpatialTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=1,
                            context_dim=self.context_dim,
                            disable_self_attn=False,
                            use_linear=True,
                            is_ctrl=True))
                    if self.temporal_attention:
                        if USE_TEMPORAL_TRANSFORMER:
                            block.append(
                                TemporalTransformer(
                                    out_dim,
                                    out_dim // head_dim,
                                    head_dim,
                                    depth=transformer_depth,
                                    context_dim=context_dim,
                                    disable_self_attn=disabled_sa,
                                    use_linear=use_linear_in_temporal,
                                    multiply_zero=use_image_dataset,
                                    is_ctrl=True,))
                        else:
                            block.append(
                                TemporalAttentionMultiBlock(
                                    out_dim,
                                    num_heads,
                                    head_dim,
                                    rotary_emb=self.rotary_emb,
                                    use_image_dataset=use_image_dataset,
                                    use_sim_mask=use_sim_mask,
                                    temporal_attn_times=temporal_attn_times))
                in_dim = out_dim
                self.input_blocks.append(block)
                self.zero_convs.append(self.make_zero_conv(out_dim))
                shortcut_dims.append(out_dim)

                # downsample
                if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
                    downsample = Downsample(
                        out_dim, True, dims=2, out_channels=out_dim)
                    shortcut_dims.append(out_dim)
                    scale /= 2.0
                    self.input_blocks.append(downsample)
                    self.zero_convs.append(self.make_zero_conv(out_dim))

        self.middle_block = nn.ModuleList([
            ResBlock(
                out_dim,
                embed_dim,
                dropout,
                use_scale_shift_norm=False,
                use_image_dataset=use_image_dataset,
            ),
            SpatialTransformer(
                out_dim,
                out_dim // head_dim,
                head_dim,
                depth=1,
                context_dim=self.context_dim,
                disable_self_attn=False,
                use_linear=True,
                is_ctrl=True)
        ])

        if self.temporal_attention:
            if USE_TEMPORAL_TRANSFORMER:
                self.middle_block.append(
                    TemporalTransformer(
                        out_dim,
                        out_dim // head_dim,
                        head_dim,
                        depth=transformer_depth,
                        context_dim=context_dim,
                        disable_self_attn=disabled_sa,
                        use_linear=use_linear_in_temporal,
                        multiply_zero=use_image_dataset,
                        is_ctrl=True,
                    ))
            else:
                self.middle_block.append(
                    TemporalAttentionMultiBlock(
                        out_dim,
                        num_heads,
                        head_dim,
                        rotary_emb=self.rotary_emb,
                        use_image_dataset=use_image_dataset,
                        use_sim_mask=use_sim_mask,
                        temporal_attn_times=temporal_attn_times))

        self.middle_block.append(
            ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))

        self.middle_block_out = self.make_zero_conv(embed_dim)

        '''
        add prompt
        '''
        add_dim = 320
        self.add_dim = add_dim

        self.input_hint_block = zero_module(
            nn.Conv2d(4, add_dim, 3, padding=1))

    def make_zero_conv(self, in_channels, out_channels=None):
        out_channels = in_channels if out_channels is None else out_channels
        return TimestepEmbedSequential(
            zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)))
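
    # Added note: make_zero_conv follows the ControlNet convention. The 1x1
    # convolution is wrapped in zero_module, so its weights and bias start at
    # zero and every control feature initially contributes nothing to the base
    # UNet until training moves the parameters away from zero.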
    def forward(
            self,
            x,
            t,
            y,
            s_cond=None,
            hint=None,
            variant_info=None,
            t_hint=None,
            mask_cond=None,
            fps=None,
            video_mask=None,
            focus_present_mask=None,
            prob_focus_present=0.,
            mask_last_frame_num=0):
        batch, _, f, _, _ = x.shape
        device = x.device
        self.batch = batch

        # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
        if mask_last_frame_num > 0:
            focus_present_mask = None
            video_mask[-mask_last_frame_num:] = False
        else:
            focus_present_mask = default(
                focus_present_mask,
                lambda: prob_mask_like((batch, ), prob_focus_present, device=device))

        if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
            time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device)
        else:
            time_rel_pos_bias = None

        if hint is not None:
            # add = x.new_zeros(batch, self.add_dim, f, h, w)
            hint = rearrange(hint, 'b c f h w -> (b f) c h w')
            hint = self.input_hint_block(hint)
            # hint = rearrange(hint, '(b f) c h w -> b c f h w', b = batch)

        e = self.time_embed(sinusoidal_embedding(t, self.dim))
        e = e.repeat_interleave(repeats=f, dim=0)

        context = y.repeat_interleave(repeats=f, dim=0)

        # always in shape (b f) c h w, except for temporal layer
        x = rearrange(x, 'b c f h w -> (b f) c h w')
        # print('before x shape:', x.shape) [64, 320, 90, 160]
        # print('hint shape:', hint.shape) [32, 320, 90, 160]

        # encoder
        xs = []
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if hint is not None:
                for block in module:
                    x = self._forward_single(block, x, e, context,
                                             time_rel_pos_bias,
                                             focus_present_mask, video_mask,
                                             variant_info=variant_info)
                    if not isinstance(block, TemporalTransformer):
                        if hint is not None:
                            x += hint
                            hint = None
            else:
                x = self._forward_single(module, x, e, context,
                                         time_rel_pos_bias,
                                         focus_present_mask, video_mask,
                                         variant_info=variant_info)
            xs.append(zero_conv(x, e, context))

        # middle
        for block in self.middle_block:
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask,
                                     variant_info=variant_info)
        xs.append(self.middle_block_out(x, e, context))

        return xs
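
    # Added note: forward returns `xs`, one zero-conv feature map per encoder
    # stage plus the middle-block output. In a ControlNet-style setup these
    # residuals are meant to be added to the matching skip connections and the
    # middle features of the base UNet; that wiring happens in the caller, not
    # in this class.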
    def _forward_single(self,
                        module,
                        x,
                        e,
                        context,
                        time_rel_pos_bias,
                        focus_present_mask,
                        video_mask,
                        reference=None,
                        variant_info=None,):
        # variant_info = None # For Debug
        if isinstance(module, ResidualBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = x.contiguous()
            x = module(x, e, reference)
        elif isinstance(module, ResBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = x.contiguous()
            x = module(x, e, self.batch, variant_info)
        elif isinstance(module, SpatialTransformer):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, TemporalTransformer):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            # print("x shape:", x.shape) # [2, 320, 32, 90, 160]
            x = module(x, context)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, CrossAttention):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, MemoryEfficientCrossAttention):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, BasicTransformerBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, FeedForward):
            x = module(x, context)
        elif isinstance(module, Upsample):
            x = module(x)
        elif isinstance(module, Downsample):
            x = module(x)
        elif isinstance(module, Resample):
            x = module(x, reference)
        elif isinstance(module, TemporalAttentionBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, TemporalAttentionMultiBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, InitTemporalConvBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, TemporalConvBlock):
            module = checkpoint_wrapper(module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, nn.ModuleList):
            for block in module:
                x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                         focus_present_mask, video_mask, reference,
                                         variant_info)
        else:
            x = module(x)
        return x
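
# Illustrative sketch (added, not part of the original repository) of how a
# caller might combine the control features with a base video UNet. The names
# `base_unet`, `control_net` and `hint_latents` are hypothetical placeholders:
#
#   controls = control_net(x, t, y, hint=hint_latents)   # list of residuals
#   out = base_unet(x, t, y, control_residuals=controls)
#
# i.e. each entry of `controls` is added to the corresponding encoder skip
# connection, and the last entry to the middle-block output, mirroring the
# standard ControlNet wiring.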
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x
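
# Minimal usage sketch (added for illustration, not part of the original model
# code): it only demonstrates how TimestepEmbedSequential routes the timestep
# embedding to children that subclass TimestepBlock, while plain layers receive
# `x` alone. The `_DummyTimestepBlock` class is a hypothetical stand-in.
if __name__ == '__main__':

    class _DummyTimestepBlock(TimestepBlock):

        def forward(self, x, emb):
            # broadcast the embedding over the spatial dimensions
            return x + emb.view(emb.size(0), -1, 1, 1)

    seq = TimestepEmbedSequential(_DummyTimestepBlock(), nn.SiLU())
    demo_x = torch.randn(2, 8, 16, 16)
    demo_emb = torch.randn(2, 8)
    # the dummy block receives (x, emb); nn.SiLU receives x only
    print(seq(demo_x, demo_emb).shape)  # torch.Size([2, 8, 16, 16])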