chenpangpang / NVComposer · Commits

Commit 30af93f2
Authored Dec 26, 2024 by chenpangpang

    feat: initial GPU commit

Parent: 68e98ab8
Pipeline #2159 canceled with stages
Changes: 66 · Pipelines: 1
Showing 20 changed files with 7594 additions and 0 deletions (+7594 / -0):

    NVComposer/core/models/utils_diffusion.py           +186   -0
    NVComposer/core/modules/attention.py                 +710   -0
    NVComposer/core/modules/attention_mv.py              +316   -0
    NVComposer/core/modules/attention_temporal.py        +1111  -0
    NVComposer/core/modules/encoders/__init__.py         +0     -0
    NVComposer/core/modules/encoders/adapter.py          +485   -0
    NVComposer/core/modules/encoders/condition.py        +511   -0
    NVComposer/core/modules/encoders/resampler.py        +264   -0
    NVComposer/core/modules/networks/ae_modules.py       +1023  -0
    NVComposer/core/modules/networks/unet_modules.py     +1047  -0
    NVComposer/core/modules/position_encoding.py         +97    -0
    NVComposer/core/modules/x_transformer.py             +679   -0
    NVComposer/main/evaluation/funcs.py                  +295   -0
    NVComposer/main/evaluation/pose_interpolation.py     +215   -0
    NVComposer/main/evaluation/utils_eval.py             +26    -0
    NVComposer/main/utils_data.py                        +164   -0
    NVComposer/requirements.txt                          +12    -0
    NVComposer/utils/constants.py                        +2     -0
    NVComposer/utils/load_weigths.py                     +252   -0
    NVComposer/utils/lr_scheduler.py                     +199   -0
NVComposer/core/models/utils_diffusion.py (new file, mode 0 → 100755)
import math

import numpy as np
import torch
from einops import repeat


def timestep_embedding(time_steps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param time_steps: a 1-D Tensor of N indices, one per batch element.
                       These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=time_steps.device)
        args = time_steps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
    else:
        embedding = repeat(time_steps, "b -> b d", d=dim)
    return embedding
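For reference, a minimal usage sketch of timestep_embedding (not part of the committed file), assuming the NVComposer root directory is on PYTHONPATH so that `core` is importable; the batch values and dim=320 are illustrative:

import torch
from core.models.utils_diffusion import timestep_embedding

t = torch.tensor([0, 10, 999])          # one (possibly fractional) timestep per batch element
emb = timestep_embedding(t, dim=320)    # cosine terms in the first half, sine terms in the second
print(emb.shape)                        # torch.Size([3, 320])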
def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    if schedule == "linear":
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )
    elif schedule == "cosine":
        time_steps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = time_steps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
def make_ddim_time_steps(
    ddim_discr_method, num_ddim_time_steps, num_ddpm_time_steps, verbose=True
):
    if ddim_discr_method == "uniform":
        c = num_ddpm_time_steps // num_ddim_time_steps
        ddim_time_steps = np.asarray(list(range(0, num_ddpm_time_steps, c)))
        steps_out = ddim_time_steps + 1
    elif ddim_discr_method == "quad":
        ddim_time_steps = (
            (np.linspace(0, np.sqrt(num_ddpm_time_steps * 0.8), num_ddim_time_steps))
            ** 2
        ).astype(int)
        steps_out = ddim_time_steps + 1
    elif ddim_discr_method == "uniform_trailing":
        c = num_ddpm_time_steps / num_ddim_time_steps
        ddim_time_steps = np.flip(
            np.round(np.arange(num_ddpm_time_steps, 0, -c))
        ).astype(np.int64)
        steps_out = ddim_time_steps - 1
    else:
        raise NotImplementedError(
            f'There is no ddim discretization method called "{ddim_discr_method}"'
        )

    # assert ddim_time_steps.shape[0] == num_ddim_time_steps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    if verbose:
        print(f"Selected time_steps for ddim sampler: {steps_out}")
    return steps_out
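A quick usage sketch (not part of the committed file) of the "uniform" discretization, assuming the NVComposer root is on PYTHONPATH; the 50/1000 step counts are illustrative:

from core.models.utils_diffusion import make_ddim_time_steps

# 50 DDIM steps out of 1000 DDPM steps: stride c = 1000 // 50 = 20, then shifted by +1
steps = make_ddim_time_steps(
    "uniform", num_ddim_time_steps=50, num_ddpm_time_steps=1000, verbose=False
)
print(steps[:5], steps[-1])  # [ 1 21 41 61 81] ... 981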
def make_ddim_sampling_parameters(alphacums, ddim_time_steps, eta, verbose=True):
    # select alphas for computing the variance schedule
    # print(f'ddim_time_steps={ddim_time_steps}, len_alphacums={len(alphacums)}')
    alphas = alphacums[ddim_time_steps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_time_steps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt(
        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
    )
    if verbose:
        print(
            f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
        )
        print(
            f"For the chosen value of eta, which is {eta}, "
            f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
        )
    return sigmas, alphas, alphas_prev
def betas_for_alpha_bar(num_diffusion_time_steps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_time_steps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_time_steps):
        t1 = i / num_diffusion_time_steps
        t2 = (i + 1) / num_diffusion_time_steps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
    Args:
        betas (`numpy.ndarray`):
            the betas that the scheduler is being initialized with.
    Returns:
        `numpy.ndarray`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    alphas_bar_sqrt = np.sqrt(alphas_cumprod)

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = np.concatenate([alphas_bar[0:1], alphas])
    betas = 1 - alphas
    return betas
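A small verification sketch (not part of the committed file) that the rescaled schedule really reaches zero terminal SNR, assuming the NVComposer root is on PYTHONPATH:

import numpy as np
from core.models.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr

betas = make_beta_schedule("linear", n_timestep=1000)   # returned as a numpy array
rescaled = rescale_zero_terminal_snr(betas)
alphas_bar = np.cumprod(1.0 - rescaled)
print(alphas_bar[-1])  # ~0.0: the final cumulative alpha vanishes, i.e. zero terminal SNR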
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(
        dim=list(range(1, noise_pred_text.ndim)), keepdim=True
    )
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    factor = guidance_rescale * (std_text / std_cfg) + (1 - guidance_rescale)
    return noise_cfg * factor
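A minimal sketch (not part of the committed file) of where rescale_noise_cfg sits in a classifier-free-guidance step, assuming the NVComposer root is on PYTHONPATH; the tensor sizes and guidance values are illustrative:

import torch
from core.models.utils_diffusion import rescale_noise_cfg

noise_pred_text = torch.randn(2, 4, 8, 8)
noise_pred_uncond = torch.randn(2, 4, 8, 8)
guidance_scale = 7.5

# standard classifier-free guidance combination
noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# pull the per-sample std back toward the text branch (Section 3.4 of arXiv:2305.08891)
noise_cfg = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)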
NVComposer/core/modules/attention.py (new file, mode 0 → 100755)
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from functools import partial

try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
except:
    XFORMERS_IS_AVAILBLE = False

from core.common import (
    gradient_checkpoint,
    exists,
    default,
)
from core.basics import zero_module


class RelativePosition(nn.Module):
    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        self.embeddings_table = nn.Parameter(
            torch.Tensor(max_relative_position * 2 + 1, num_units)
        )
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        device = self.embeddings_table.device
        range_vec_q = torch.arange(length_q, device=device)
        range_vec_k = torch.arange(length_k, device=device)
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        distance_mat_clipped = torch.clamp(
            distance_mat, -self.max_relative_position, self.max_relative_position
        )
        final_mat = distance_mat_clipped + self.max_relative_position
        final_mat = final_mat.long()
        embeddings = self.embeddings_table[final_mat]
        return embeddings
class CrossAttention(nn.Module):
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        relative_position=False,
        temporal_length=None,
        video_length=None,
        image_cross_attention=False,
        image_cross_attention_scale=1.0,
        image_cross_attention_scale_learnable=False,
        text_context_len=77,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

        self.relative_position = relative_position
        if self.relative_position:
            assert temporal_length is not None
            self.relative_position_k = RelativePosition(
                num_units=dim_head, max_relative_position=temporal_length
            )
            self.relative_position_v = RelativePosition(
                num_units=dim_head, max_relative_position=temporal_length
            )
        else:
            # only used for spatial attention, while NOT for temporal attention
            if XFORMERS_IS_AVAILBLE and temporal_length is None:
                self.forward = self.efficient_forward

        self.video_length = video_length
        self.image_cross_attention = image_cross_attention
        self.image_cross_attention_scale = image_cross_attention_scale
        self.text_context_len = text_context_len
        self.image_cross_attention_scale_learnable = (
            image_cross_attention_scale_learnable
        )
        if self.image_cross_attention:
            self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
            if image_cross_attention_scale_learnable:
                self.register_parameter("alpha", nn.Parameter(torch.tensor(0.0)))

    def forward(self, x, context=None, mask=None):
        spatial_self_attn = context is None
        k_ip, v_ip, out_ip = None, None, None

        h = self.heads
        q = self.to_q(x)
        context = default(context, x)

        if self.image_cross_attention and not spatial_self_attn:
            context, context_image = (
                context[:, : self.text_context_len, :],
                context[:, self.text_context_len :, :],
            )
            k = self.to_k(context)
            v = self.to_v(context)
            k_ip = self.to_k_ip(context_image)
            v_ip = self.to_v_ip(context_image)
        else:
            if not spatial_self_attn:
                context = context[:, : self.text_context_len, :]
            k = self.to_k(context)
            v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
        if self.relative_position:
            len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
            k2 = self.relative_position_k(len_q, len_k)
            sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale
            sim += sim2
        del k

        if exists(mask):
            # feasible for causal attention mask only
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b i j -> (b h) i j", h=h)
            sim.masked_fill_(~(mask > 0.5), max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)
        out = torch.einsum("b i j, b j d -> b i d", sim, v)

        if self.relative_position:
            v2 = self.relative_position_v(len_q, len_v)
            out2 = einsum("b t s, t s d -> b t d", sim, v2)
            out += out2
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)

        # for image cross-attention
        if k_ip is not None:
            k_ip, v_ip = map(
                lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (k_ip, v_ip)
            )
            sim_ip = torch.einsum("b i d, b j d -> b i j", q, k_ip) * self.scale
            del k_ip
            sim_ip = sim_ip.softmax(dim=-1)
            out_ip = torch.einsum("b i j, b j d -> b i d", sim_ip, v_ip)
            out_ip = rearrange(out_ip, "(b h) n d -> b n (h d)", h=h)

        if out_ip is not None:
            if self.image_cross_attention_scale_learnable:
                out = out + self.image_cross_attention_scale * out_ip * (
                    torch.tanh(self.alpha) + 1
                )
            else:
                out = out + self.image_cross_attention_scale * out_ip

        return self.to_out(out)

    def efficient_forward(self, x, context=None, mask=None):
        spatial_self_attn = context is None
        k_ip, v_ip, out_ip = None, None, None

        q = self.to_q(x)
        context = default(context, x)

        if self.image_cross_attention and not spatial_self_attn:
            context, context_image = (
                context[:, : self.text_context_len, :],
                context[:, self.text_context_len :, :],
            )
            k = self.to_k(context)
            v = self.to_v(context)
            k_ip = self.to_k_ip(context_image)
            v_ip = self.to_v_ip(context_image)
        else:
            if not spatial_self_attn:
                context = context[:, : self.text_context_len, :]
            k = self.to_k(context)
            v = self.to_v(context)

        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )
        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)

        # for image cross-attention
        if k_ip is not None:
            k_ip, v_ip = map(
                lambda t: t.unsqueeze(3)
                .reshape(b, t.shape[1], self.heads, self.dim_head)
                .permute(0, 2, 1, 3)
                .reshape(b * self.heads, t.shape[1], self.dim_head)
                .contiguous(),
                (k_ip, v_ip),
            )
            out_ip = xformers.ops.memory_efficient_attention(
                q, k_ip, v_ip, attn_bias=None, op=None
            )
            out_ip = (
                out_ip.unsqueeze(0)
                .reshape(b, self.heads, out.shape[1], self.dim_head)
                .permute(0, 2, 1, 3)
                .reshape(b, out.shape[1], self.heads * self.dim_head)
            )

        if exists(mask):
            raise NotImplementedError
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        if out_ip is not None:
            if self.image_cross_attention_scale_learnable:
                out = out + self.image_cross_attention_scale * out_ip * (
                    torch.tanh(self.alpha) + 1
                )
            else:
                out = out + self.image_cross_attention_scale * out_ip

        return self.to_out(out)
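A minimal usage sketch of CrossAttention with image cross-attention (not part of the committed file), assuming the NVComposer root is on PYTHONPATH; the 320/1024 widths and the 16 image tokens are illustrative, and with xformers installed the call dispatches to efficient_forward, which may expect CUDA tensors. When image_cross_attention=True, the context is laid out as the first text_context_len text tokens followed by the image tokens, and the module splits it internally:

import torch
from core.modules.attention import CrossAttention

attn = CrossAttention(
    query_dim=320, context_dim=1024, heads=8, dim_head=40,
    image_cross_attention=True, text_context_len=77,
)
x = torch.randn(2, 64, 320)              # (batch, query tokens, query_dim)
context = torch.randn(2, 77 + 16, 1024)  # 77 text tokens followed by 16 image tokens
out = attn(x, context=context)           # (2, 64, 320)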
class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
        attention_cls=None,
        video_length=None,
        image_cross_attention=False,
        image_cross_attention_scale=1.0,
        image_cross_attention_scale_learnable=False,
        text_context_len=77,
        enable_lora=False,
    ):
        super().__init__()
        attn_cls = CrossAttention if attention_cls is None else attention_cls
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, enable_lora=enable_lora)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            video_length=video_length,
            image_cross_attention=image_cross_attention,
            image_cross_attention_scale=image_cross_attention_scale,
            image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
            text_context_len=text_context_len,
        )
        self.image_cross_attention = image_cross_attention

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint
        self.enable_lora = enable_lora

    def forward(self, x, context=None, mask=None, with_lora=False, **kwargs):
        # implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
        # should not be (x), otherwise *input_tuple will decouple x into multiple arguments
        input_tuple = (x,)
        if context is not None:
            input_tuple = (x, context)
        if mask is not None:
            _forward = partial(self._forward, mask=None, with_lora=with_lora)
        else:
            _forward = partial(self._forward, mask=mask, with_lora=with_lora)
        return gradient_checkpoint(
            _forward, input_tuple, self.parameters(), self.checkpoint
        )

    def _forward(self, x, context=None, mask=None, with_lora=False):
        x = (
            self.attn1(
                self.norm1(x),
                context=context if self.disable_self_attn else None,
                mask=mask,
            )
            + x
        )
        x = self.attn2(self.norm2(x), context=context, mask=mask) + x
        x = self.ff(self.norm3(x), with_lora=with_lora) + x
        return x
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data in spatial axis.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        use_checkpoint=True,
        disable_self_attn=False,
        use_linear=False,
        video_length=None,
        image_cross_attention=False,
        image_cross_attention_scale_learnable=False,
        enable_lora=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        self.enable_lora = enable_lora
        attention_cls = None
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim,
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                    attention_cls=attention_cls,
                    video_length=video_length,
                    image_cross_attention=image_cross_attention,
                    image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
                    enable_lora=self.enable_lora,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None, with_lora=False, **kwargs):
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context, with_lora=with_lora, **kwargs)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
class TemporalTransformer(nn.Module):
    """
    Transformer block for image-like data in temporal axis.
    First, reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        use_checkpoint=True,
        use_linear=False,
        only_self_att=True,
        causal_attention=False,
        causal_block_size=1,
        relative_position=False,
        temporal_length=None,
        use_extra_spatial_temporal_self_attention=False,
        enable_lora=False,
        full_spatial_temporal_attention=False,
        enhance_multi_view_correspondence=False,
    ):
        super().__init__()
        self.only_self_att = only_self_att
        self.relative_position = relative_position
        self.causal_attention = causal_attention
        self.causal_block_size = causal_block_size
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.proj_in = nn.Conv1d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
        )
        if not use_linear:
            self.proj_in = nn.Conv1d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        if relative_position:
            assert temporal_length is not None
            attention_cls = partial(
                CrossAttention, relative_position=True, temporal_length=temporal_length
            )
        else:
            attention_cls = partial(CrossAttention, temporal_length=temporal_length)
        if self.causal_attention:
            assert temporal_length is not None
            self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))

        if self.only_self_att:
            context_dim = None
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim,
                    attention_cls=attention_cls,
                    checkpoint=use_checkpoint,
                    enable_lora=enable_lora,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

        self.use_extra_spatial_temporal_self_attention = (
            use_extra_spatial_temporal_self_attention
        )
        if use_extra_spatial_temporal_self_attention:
            from core.modules.attention_mv import MultiViewSelfAttentionTransformer

            self.extra_spatial_time_self_attention = MultiViewSelfAttentionTransformer(
                in_channels=in_channels,
                n_heads=n_heads,
                d_head=d_head,
                num_views=temporal_length,
                depth=depth,
                use_linear=use_linear,
                use_checkpoint=use_checkpoint,
                full_spatial_temporal_attention=full_spatial_temporal_attention,
                enhance_multi_view_correspondence=enhance_multi_view_correspondence,
            )

    def forward(self, x, context=None, with_lora=False, time_steps=None):
        b, c, t, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = rearrange(x, "b c t h w -> (b h w) c t").contiguous()
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "bhw c t -> bhw t c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)

        temp_mask = None
        if self.causal_attention:
            # slice from the mask map
            temp_mask = self.mask[:, :t, :t].to(x.device)

        if temp_mask is not None:
            mask = temp_mask.to(x.device)
            mask = repeat(mask, "l i j -> (l bhw) i j", bhw=b * h * w)
        else:
            mask = None

        if self.only_self_att:
            # note: if no context is given, cross-attention defaults to self-attention
            for i, block in enumerate(self.transformer_blocks):
                x = block(x, mask=mask, with_lora=with_lora)
            x = rearrange(x, "(b hw) t c -> b hw t c", b=b).contiguous()
        else:
            x = rearrange(x, "(b hw) t c -> b hw t c", b=b).contiguous()
            context = rearrange(context, "(b t) l con -> b t l con", t=t).contiguous()
            for i, block in enumerate(self.transformer_blocks):
                # calculate each batch one by one (since the number in shape cannot be greater than 65,535 for some packages)
                for j in range(b):
                    context_j = repeat(
                        context[j], "t l con -> (t r) l con", r=(h * w) // t, t=t
                    ).contiguous()
                    # note: the causal mask is not applied in the cross-attention case
                    x[j] = block(x[j], context=context_j, with_lora=with_lora)

        if self.use_linear:
            x = self.proj_out(x)
            x = rearrange(x, "b (h w) t c -> b c t h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = rearrange(x, "b hw t c -> (b hw) c t").contiguous()
            x = self.proj_out(x)
            x = rearrange(x, "(b h w) c t -> b c t h w", b=b, h=h, w=w).contiguous()

        res = x + x_in
        if self.use_extra_spatial_temporal_self_attention:
            res = rearrange(res, "b c t h w -> (b t) c h w", b=b, h=h, w=w).contiguous()
            res = self.extra_spatial_time_self_attention(res, time_steps=time_steps)
            res = rearrange(res, "(b t) c h w -> b c t h w", b=b, h=h, w=w).contiguous()
        return res
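To make the shape flow above concrete, a standalone trace (not part of the committed file) of how TemporalTransformer folds every spatial location into the batch so attention runs along the frame axis; it uses only einops and the sizes are illustrative:

import torch
from einops import rearrange

b, c, t, h, w = 1, 64, 8, 16, 16
x = torch.randn(b, c, t, h, w)
seq = rearrange(x, "b c t h w -> (b h w) c t")                    # (256, 64, 8)
seq = rearrange(seq, "bhw c t -> bhw t c")                        # (256, 8, 64): one temporal sequence per pixel
# ... transformer blocks operate on (batch * h * w, t, c) ...
back = rearrange(seq, "(b h w) t c -> b c t h w", b=b, h=h, w=w)  # back to (1, 64, 8, 16, 16)
assert back.shape == x.shape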
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)
class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        dim_out=None,
        mult=4,
        glu=False,
        dropout=0.0,
        enable_lora=False,
        lora_rank=32,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )
        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )
        self.enable_lora = enable_lora
        self.lora_rank = lora_rank
        self.lora_alpha = 16
        if self.enable_lora:
            assert (
                self.lora_rank is not None
            ), "`lora_rank` must be given when `enable_lora` is True."
            assert (
                0 < self.lora_rank < min(dim, dim_out)
            ), f"`lora_rank` must be range [0, min(inner_dim={inner_dim}, dim_out={dim_out})], but got {self.lora_rank}."
            self.lora_a = nn.Parameter(
                torch.zeros((inner_dim, self.lora_rank), requires_grad=True)
            )
            self.lora_b = nn.Parameter(
                torch.zeros((self.lora_rank, dim_out), requires_grad=True)
            )
            self.scaling = self.lora_alpha / self.lora_rank

    def forward(self, x, with_lora=False):
        if with_lora:
            projected_x = self.net[1](self.net[0](x))
            lora_x = (
                torch.matmul(projected_x, torch.matmul(self.lora_a, self.lora_b))
                * self.scaling
            )
            original_x = self.net[2](projected_x)
            return original_x + lora_x
        else:
            return self.net(x)
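A small usage sketch of the LoRA path in this FeedForward (not part of the committed file), assuming the NVComposer root is on PYTHONPATH; dim=320 and the token count are illustrative. Note that lora_a and lora_b are zero-initialized, so the low-rank branch contributes nothing until it is trained:

import torch
from core.modules.attention import FeedForward

ff = FeedForward(dim=320, glu=True, enable_lora=True, lora_rank=32)
x = torch.randn(2, 64, 320)
y_base = ff(x, with_lora=False)  # plain GEGLU MLP
y_lora = ff(x, with_lora=True)   # adds projected_x @ (lora_a @ lora_b) * (lora_alpha / lora_rank)
print(torch.allclose(y_base, y_lora))  # True at init, since the LoRA matrices start at zero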
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)
class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        return x + h_
NVComposer/core/modules/attention_mv.py (new file, mode 0 → 100755)
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from core.common import gradient_checkpoint

try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
except:
    XFORMERS_IS_AVAILBLE = False
print(f"XFORMERS_IS_AVAILBLE: {XFORMERS_IS_AVAILBLE}")


def get_group_norm_layer(in_channels):
    if in_channels < 32:
        if in_channels % 2 == 0:
            num_groups = in_channels // 2
        else:
            num_groups = in_channels
    else:
        num_groups = 32
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )
        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)
class SpatialTemporalAttention(nn.Module):
    """Uses xformers to implement efficient epipolar masking for cross-attention between views."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        if context_dim is None:
            context_dim = query_dim

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        self.attention_op = None

    def forward(self, x, context=None, enhance_multi_view_correspondence=False):
        q = self.to_q(x)
        if context is None:
            context = x
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        if enhance_multi_view_correspondence:
            with torch.no_grad():
                normalized_x = torch.nn.functional.normalize(x.detach(), p=2, dim=-1)
                cosine_sim_map = torch.bmm(
                    normalized_x, normalized_x.transpose(-1, -2)
                )
                attn_bias = torch.where(cosine_sim_map > 0.0, 0.0, -1e9).to(
                    dtype=q.dtype
                )
                attn_bias = rearrange(
                    attn_bias.unsqueeze(1).expand(-1, self.heads, -1, -1),
                    "b h d1 d2 -> (b h) d1 d2",
                ).detach()
        else:
            attn_bias = None

        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=attn_bias, op=self.attention_op
        )

        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        del q, k, v, attn_bias
        return self.to_out(out)
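A standalone illustration (not part of the committed file) of the attention bias built when enhance_multi_view_correspondence is enabled: tokens may only attend to tokens whose features point in a similar direction (cosine similarity above zero). It uses only torch and einops; the batch/token/channel sizes and 4 heads are illustrative:

import torch
from einops import rearrange

heads = 4
x = torch.randn(2, 128, 64)  # (batch, tokens, channels)
normalized_x = torch.nn.functional.normalize(x, p=2, dim=-1)
cosine_sim_map = torch.bmm(normalized_x, normalized_x.transpose(-1, -2))  # (2, 128, 128)
attn_bias = torch.where(cosine_sim_map > 0.0, 0.0, -1e9)                  # 0 keeps a pair, -1e9 masks it
attn_bias = rearrange(
    attn_bias.unsqueeze(1).expand(-1, heads, -1, -1), "b h i j -> (b h) i j"
)                                                                          # (8, 128, 128), one copy per head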
class MultiViewSelfAttentionTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        gated_ff=True,
        use_checkpoint=True,
        full_spatial_temporal_attention=False,
        enhance_multi_view_correspondence=False,
    ):
        super().__init__()
        attn_cls = SpatialTemporalAttention
        # self.self_attention_only = self_attention_only
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=None,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        if enhance_multi_view_correspondence:
            # Zero initialization when MVCorr is enabled.
            zero_module_fn = zero_module
        else:

            def zero_module_fn(x):
                return x

        self.attn2 = zero_module_fn(
            attn_cls(
                query_dim=dim,
                heads=n_heads,
                dim_head=d_head,
                dropout=dropout,
                context_dim=None,
            )
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.use_checkpoint = use_checkpoint
        self.full_spatial_temporal_attention = full_spatial_temporal_attention
        self.enhance_multi_view_correspondence = enhance_multi_view_correspondence

    def forward(self, x, time_steps=None):
        return gradient_checkpoint(
            self.many_stream_forward, (x, time_steps), None, flag=self.use_checkpoint
        )

    def many_stream_forward(self, x, time_steps=None):
        n, v, hw = x.shape[:3]
        x = rearrange(x, "n v hw c -> n (v hw) c")
        x = (
            self.attn1(
                self.norm1(x), context=None, enhance_multi_view_correspondence=False
            )
            + x
        )
        if not self.full_spatial_temporal_attention:
            x = rearrange(x, "n (v hw) c -> n v hw c", v=v)
            x = rearrange(x, "n v hw c -> (n v) hw c")
        x = (
            self.attn2(
                self.norm2(x),
                context=None,
                enhance_multi_view_correspondence=self.enhance_multi_view_correspondence
                and hw <= 256,
            )
            + x
        )
        x = self.ff(self.norm3(x)) + x
        if self.full_spatial_temporal_attention:
            x = rearrange(x, "n (v hw) c -> n v hw c", v=v)
        else:
            x = rearrange(x, "(n v) hw c -> n v hw c", v=v)
        return x
class MultiViewSelfAttentionTransformer(nn.Module):
    """Spatial Transformer block with post init to add cross attn."""

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        num_views,
        depth=1,
        dropout=0.0,
        use_linear=True,
        use_checkpoint=True,
        zero_out_initialization=True,
        full_spatial_temporal_attention=False,
        enhance_multi_view_correspondence=False,
    ):
        super().__init__()
        self.num_views = num_views
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = get_group_norm_layer(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList(
            [
                MultiViewSelfAttentionTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    use_checkpoint=use_checkpoint,
                    full_spatial_temporal_attention=full_spatial_temporal_attention,
                    enhance_multi_view_correspondence=enhance_multi_view_correspondence,
                )
                for d in range(depth)
            ]
        )
        self.zero_out_initialization = zero_out_initialization

        if zero_out_initialization:
            _zero_func = zero_module
        else:

            def _zero_func(x):
                return x

        if not use_linear:
            self.proj_out = _zero_func(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            self.proj_out = _zero_func(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, time_steps=None):
        # x: bt c h w
        _, c, h, w = x.shape
        n_views = self.num_views
        x_in = x
        x = self.norm(x)
        x = rearrange(x, "(n v) c h w -> n v (h w) c", v=n_views)
        if self.use_linear:
            x = rearrange(x, "n v x c -> (n v) x c")
            x = self.proj_in(x)
            x = rearrange(x, "(n v) x c -> n v x c", v=n_views)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, time_steps=time_steps)
        if self.use_linear:
            x = rearrange(x, "n v x c -> (n v) x c")
            x = self.proj_out(x)
            x = rearrange(x, "(n v) x c -> n v x c", v=n_views)
        x = rearrange(x, "n v (h w) c -> (n v) c h w", h=h, w=w).contiguous()
        return x + x_in
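To make the expected input layout concrete, a standalone reshaping trace (not part of the committed file): MultiViewSelfAttentionTransformer.forward takes views folded into the batch as (n * num_views, c, h, w), matching the "(b t) c h w" layout TemporalTransformer produces before calling extra_spatial_time_self_attention. Only torch and einops are used; the sizes are illustrative:

import torch
from einops import rearrange

n, num_views, c, h, w = 2, 8, 64, 16, 16
x = torch.randn(n * num_views, c, h, w)
tokens = rearrange(x, "(n v) c h w -> n v (h w) c", v=num_views)  # (2, 8, 256, 64)
# full_spatial_temporal_attention=True attends over all v * (h*w) tokens jointly:
joint = rearrange(tokens, "n v hw c -> n (v hw) c")               # (2, 2048, 64)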
NVComposer/core/modules/attention_temporal.py (new file, mode 0 → 100755)
import math

import torch
import torch as th
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn, einsum

try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
except:
    XFORMERS_IS_AVAILBLE = False

from core.common import gradient_checkpoint, exists, default
from core.basics import conv_nd, zero_module, normalization


class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )
        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )
class RelativePosition(nn.Module):
    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        self.embeddings_table = nn.Parameter(
            th.Tensor(max_relative_position * 2 + 1, num_units)
        )
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        device = self.embeddings_table.device
        range_vec_q = th.arange(length_q, device=device)
        range_vec_k = th.arange(length_k, device=device)
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        distance_mat_clipped = th.clamp(
            distance_mat, -self.max_relative_position, self.max_relative_position
        )
        final_mat = distance_mat_clipped + self.max_relative_position
        final_mat = final_mat.long()
        embeddings = self.embeddings_table[final_mat]
        return embeddings
class TemporalCrossAttention(nn.Module):
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        # For relative positional representation and image-video joint training.
        temporal_length=None,
        image_length=None,
        # For image-video joint training.
        # whether use relative positional representation in temporal attention.
        use_relative_position=False,
        # For image-video joint training.
        img_video_joint_train=False,
        use_tempoal_causal_attn=False,
        bidirectional_causal_attn=False,
        tempoal_attn_type=None,
        joint_train_mode="same_batch",
        **kwargs,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.context_dim = context_dim
        self.scale = dim_head**-0.5
        self.heads = heads
        self.temporal_length = temporal_length
        self.use_relative_position = use_relative_position
        self.img_video_joint_train = img_video_joint_train
        self.bidirectional_causal_attn = bidirectional_causal_attn
        self.joint_train_mode = joint_train_mode
        assert joint_train_mode in ["same_batch", "diff_batch"]
        self.tempoal_attn_type = tempoal_attn_type

        if bidirectional_causal_attn:
            assert use_tempoal_causal_attn
        if tempoal_attn_type:
            assert tempoal_attn_type in ["sparse_causal", "sparse_causal_first"]
            assert not use_tempoal_causal_attn
            assert not (
                img_video_joint_train and (self.joint_train_mode == "same_batch")
            )

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        assert not (
            img_video_joint_train
            and (self.joint_train_mode == "same_batch")
            and use_tempoal_causal_attn
        )
        if img_video_joint_train:
            if self.joint_train_mode == "same_batch":
                mask = torch.ones(
                    [
                        1,
                        temporal_length + image_length,
                        temporal_length + image_length,
                    ]
                )
                mask[:, temporal_length:, :] = 0
                mask[:, :, temporal_length:] = 0
                self.mask = mask
            else:
                self.mask = None
        elif use_tempoal_causal_attn:
            # normal causal attn
            self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
        elif tempoal_attn_type == "sparse_causal":
            # true indicates keeping
            mask1 = torch.tril(
                torch.ones([1, temporal_length, temporal_length])
            ).bool()
            # initialize to same shape with mask1
            mask2 = torch.zeros([1, temporal_length, temporal_length])
            mask2[:, 2:temporal_length, : temporal_length - 2] = torch.tril(
                torch.ones([1, temporal_length - 2, temporal_length - 2])
            )
            mask2 = (1 - mask2).bool()  # false indicates masking
            self.mask = mask1 & mask2
        elif tempoal_attn_type == "sparse_causal_first":
            # true indicates keeping
            mask1 = torch.tril(
                torch.ones([1, temporal_length, temporal_length])
            ).bool()
            mask2 = torch.zeros([1, temporal_length, temporal_length])
            mask2[:, 2:temporal_length, 1 : temporal_length - 1] = torch.tril(
                torch.ones([1, temporal_length - 2, temporal_length - 2])
            )
            mask2 = (1 - mask2).bool()  # false indicates masking
            self.mask = mask1 & mask2
        else:
            self.mask = None

        if use_relative_position:
            assert temporal_length is not None
            self.relative_position_k = RelativePosition(
                num_units=dim_head, max_relative_position=temporal_length
            )
            self.relative_position_v = RelativePosition(
                num_units=dim_head, max_relative_position=temporal_length
            )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

        nn.init.constant_(self.to_q.weight, 0)
        nn.init.constant_(self.to_k.weight, 0)
        nn.init.constant_(self.to_v.weight, 0)
        nn.init.constant_(self.to_out[0].weight, 0)
        nn.init.constant_(self.to_out[0].bias, 0)

    def forward(self, x, context=None, mask=None):
        nh = self.heads
        out = x

        q = self.to_q(out)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(
            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=nh), (q, k, v)
        )
        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if self.use_relative_position:
            len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
            k2 = self.relative_position_k(len_q, len_k)
            sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale
            sim += sim2

        if exists(self.mask):
            if mask is None:
                mask = self.mask.to(sim.device)
            else:
                mask = self.mask.to(sim.device).bool() & mask  # .to(sim.device)
        else:
            mask = mask
        if mask is not None:
            max_neg_value = -1e9
            sim = sim + (1 - mask.float()) * max_neg_value  # 1=masking,0=no masking

        attn = sim.softmax(dim=-1)
        out = einsum("b i j, b j d -> b i d", attn, v)

        if self.bidirectional_causal_attn:
            mask_reverse = torch.triu(
                torch.ones(
                    [1, self.temporal_length, self.temporal_length], device=sim.device
                )
            )
            sim_reverse = sim.float().masked_fill(mask_reverse == 0, max_neg_value)
            attn_reverse = sim_reverse.softmax(dim=-1)
            out_reverse = einsum("b i j, b j d -> b i d", attn_reverse, v)
            out += out_reverse

        if self.use_relative_position:
            v2 = self.relative_position_v(len_q, len_v)
            out2 = einsum("b t s, t s d -> b t d", attn, v2)
            out += out2
        out = rearrange(out, "(b h) n d -> b n (h d)", h=nh)
        return self.to_out(out)
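A standalone sketch (not part of the committed file) of the "sparse_causal" mask for a small temporal_length, built with the same two-mask construction as above. With temporal_length=5, frame 0 attends only to itself and every later frame attends to itself and its previous frame:

import torch

temporal_length = 5
mask1 = torch.tril(torch.ones([1, temporal_length, temporal_length])).bool()  # plain causal, True = keep
mask2 = torch.zeros([1, temporal_length, temporal_length])
mask2[:, 2:temporal_length, : temporal_length - 2] = torch.tril(
    torch.ones([1, temporal_length - 2, temporal_length - 2])
)
mask2 = (1 - mask2).bool()   # drops keys older than the immediately previous frame
print((mask1 & mask2).int()[0])
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 1, 1]], dtype=torch.int32)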
class CrossAttention(nn.Module):
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        sa_shared_kv=False,
        shared_type="only_first",
        **kwargs,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.sa_shared_kv = sa_shared_kv
        assert shared_type in [
            "only_first",
            "all_frames",
            "first_and_prev",
            "only_prev",
            "full",
            "causal",
            "full_qkv",
        ]
        self.shared_type = shared_type

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        if XFORMERS_IS_AVAILBLE:
            self.forward = self.efficient_forward

    def forward(self, x, context=None, mask=None):
        h = self.heads
        b = x.shape[0]

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        if self.sa_shared_kv:
            if self.shared_type == "only_first":
                k, v = map(
                    lambda xx: rearrange(xx[0].unsqueeze(0), "b n c -> (b n) c")
                    .unsqueeze(0)
                    .repeat(b, 1, 1),
                    (k, v),
                )
            else:
                raise NotImplementedError

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)

    def efficient_forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )
        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)

        if exists(mask):
            raise NotImplementedError
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)
class VideoSpatialCrossAttention(CrossAttention):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0):
        super().__init__(query_dim, context_dim, heads, dim_head, dropout)

    def forward(self, x, context=None, mask=None):
        b, c, t, h, w = x.shape
        if context is not None:
            context = context.repeat(t, 1, 1)
        x = super().forward(spatial_attn_reshape(x), context=context) + x
        return spatial_attn_reshape_back(x, b, h)
class BasicTransformerBlockST(nn.Module):
    def __init__(
        self,
        # Spatial Stuff
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        # Temporal Stuff
        temporal_length=None,
        image_length=None,
        use_relative_position=True,
        img_video_joint_train=False,
        cross_attn_on_tempoal=False,
        temporal_crossattn_type="selfattn",
        order="stst",
        temporalcrossfirst=False,
        temporal_context_dim=None,
        split_stcontext=False,
        local_spatial_temporal_attn=False,
        window_size=2,
        **kwargs,
    ):
        super().__init__()
        # Self attention
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            **kwargs,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # cross attention if context is not None
        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            **kwargs,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint
        self.order = order
        assert self.order in ["stst", "sstt", "st_parallel"]
        self.temporalcrossfirst = temporalcrossfirst
        self.split_stcontext = split_stcontext
        self.local_spatial_temporal_attn = local_spatial_temporal_attn
        if self.local_spatial_temporal_attn:
            assert self.order == "stst"
            assert self.order == "stst"
        self.window_size = window_size
        if not split_stcontext:
            temporal_context_dim = context_dim

        # Temporal attention
        assert temporal_crossattn_type in ["selfattn", "crossattn", "skip"]
        self.temporal_crossattn_type = temporal_crossattn_type
        self.attn1_tmp = TemporalCrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            temporal_length=temporal_length,
            image_length=image_length,
            use_relative_position=use_relative_position,
            img_video_joint_train=img_video_joint_train,
            **kwargs,
        )
        self.attn2_tmp = TemporalCrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            # cross attn
            context_dim=(
                temporal_context_dim if temporal_crossattn_type == "crossattn" else None
            ),
            # temporal attn
            temporal_length=temporal_length,
            image_length=image_length,
            use_relative_position=use_relative_position,
            img_video_joint_train=img_video_joint_train,
            **kwargs,
        )
        self.norm4 = nn.LayerNorm(dim)
        self.norm5 = nn.LayerNorm(dim)
    def forward(
        self,
        x,
        context=None,
        temporal_context=None,
        no_temporal_attn=None,
        attn_mask=None,
        **kwargs,
    ):
        if not self.split_stcontext:
            # st cross attention use the same context vector
            temporal_context = context.detach().clone()
        if context is None and temporal_context is None:
            # self-attention models
            if no_temporal_attn:
                raise NotImplementedError
            return gradient_checkpoint(
                self._forward_nocontext, (x), self.parameters(), self.checkpoint
            )
        else:
            # cross-attention models
            if no_temporal_attn:
                forward_func = self._forward_no_temporal_attn
            else:
                forward_func = self._forward
            inputs = (
                (x, context, temporal_context)
                if temporal_context is not None
                else (x, context)
            )
            return gradient_checkpoint(
                forward_func, inputs, self.parameters(), self.checkpoint
            )

    def _forward(
        self,
        x,
        context=None,
        temporal_context=None,
        mask=None,
        no_temporal_attn=None,
    ):
        assert x.dim() == 5, f"x shape = {x.shape}"
        b, c, t, h, w = x.shape

        if self.order in ["stst", "sstt"]:
            x = self._st_cross_attn(
                x,
                context,
                temporal_context=temporal_context,
                order=self.order,
                mask=mask,
            )  # no_temporal_attn=no_temporal_attn,
        elif self.order == "st_parallel":
            x = self._st_cross_attn_parallel(
                x,
                context,
                temporal_context=temporal_context,
                order=self.order,
            )  # no_temporal_attn=no_temporal_attn,
        else:
            raise NotImplementedError

        x = self.ff(self.norm3(x)) + x
        if (no_temporal_attn is None) or (not no_temporal_attn):
            x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)  # 3d -> 5d
        elif no_temporal_attn:
            x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h, w=w)  # 3d -> 5d
        return x

    def _forward_no_temporal_attn(
        self,
        x,
        context=None,
        temporal_context=None,
    ):
        assert x.dim() == 5, f"x shape = {x.shape}"
        b, c, t, h, w = x.shape

        if self.order in ["stst", "sstt"]:
            mask = torch.zeros([1, t, t], device=x.device).bool()
            x = self._st_cross_attn(
                x,
                context,
                temporal_context=temporal_context,
                order=self.order,
                mask=mask,
            )
        elif self.order == "st_parallel":
            x = self._st_cross_attn_parallel(
                x,
                context,
                temporal_context=temporal_context,
                order=self.order,
                no_temporal_attn=True,
            )
        else:
            raise NotImplementedError

        x = self.ff(self.norm3(x)) + x
        x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)  # 3d -> 5d
        return x

    def _forward_nocontext(self, x, no_temporal_attn=None):
        assert x.dim() == 5, f"x shape = {x.shape}"
        b, c, t, h, w = x.shape

        if self.order in ["stst", "sstt"]:
            x = self._st_cross_attn(
                x, order=self.order, no_temporal_attn=no_temporal_attn
            )
        elif self.order == "st_parallel":
            x = self._st_cross_attn_parallel(
                x, order=self.order, no_temporal_attn=no_temporal_attn
            )
        else:
            raise NotImplementedError

        x = self.ff(self.norm3(x)) + x
        x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)  # 3d -> 5d
        return x
    def _st_cross_attn(
        self,
        x,
        context=None,
        temporal_context=None,
        order="stst",
        mask=None,
        no_temporal_attn=None,  # accepted because _forward_nocontext and the "sstt" branch use it
    ):
        b, c, t, h, w = x.shape
        if order == "stst":
            # spatial self attention
            x = rearrange(x, "b c t h w -> (b t) (h w) c")
            x = self.attn1(self.norm1(x)) + x
            x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)

            # temporal self attention (optionally over local windows)
            if self.local_spatial_temporal_attn:
                x = local_spatial_temporal_attn_reshape(x, window_size=self.window_size)
            else:
                x = rearrange(x, "b c t h w -> (b h w) t c")
            x = self.attn1_tmp(self.norm4(x), mask=mask) + x
            if self.local_spatial_temporal_attn:
                x = local_spatial_temporal_attn_reshape_back(
                    x, window_size=self.window_size, b=b, h=h, w=w, t=t
                )
            else:
                x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)  # 3d -> 5d

            # spatial cross attention
            x = rearrange(x, "b c t h w -> (b t) (h w) c")
            if context is not None:
                if context.shape[0] == t:
                    # image captions, no_temporal_attn or per-frame context
                    context_ = context
                else:
                    context_ = []
                    for i in range(context.shape[0]):
                        context_.append(context[i].unsqueeze(0).repeat(t, 1, 1))
                    context_ = torch.cat(context_, dim=0)
            else:
                context_ = None
            x = self.attn2(self.norm2(x), context=context_) + x

            # temporal cross attention
            # if (no_temporal_attn is None) or (not no_temporal_attn):
            x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
            x = rearrange(x, "b c t h w -> (b h w) t c")
            if self.temporal_crossattn_type == "crossattn":
                # temporal cross attention
                if temporal_context is not None:
                    # print(f'STATTN context={context.shape}, temporal_context={temporal_context.shape}')
                    temporal_context = torch.cat([context, temporal_context], dim=1)  # blc
                    # print(f'STATTN after concat temporal_context={temporal_context.shape}')
                    temporal_context = temporal_context.repeat(h * w, 1, 1)
                    # print(f'after repeat temporal_context={temporal_context.shape}')
                else:
                    temporal_context = context[0:1, ...].repeat(h * w, 1, 1)
                # print(f'STATTN after concat x={x.shape}')
                x = self.attn2_tmp(self.norm5(x), context=temporal_context, mask=mask) + x
            elif self.temporal_crossattn_type == "selfattn":
                # temporal self attention
                x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
            elif self.temporal_crossattn_type == "skip":
                # no temporal cross and self attention
                pass
            else:
                raise NotImplementedError

        elif order == "sstt":
            # spatial self attention
            x = rearrange(x, "b c t h w -> (b t) (h w) c")
            x = self.attn1(self.norm1(x)) + x
            # spatial cross attention
            context_ = context.repeat(t, 1, 1) if context is not None else None
            x = self.attn2(self.norm2(x), context=context_) + x
            x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)

            if (no_temporal_attn is None) or (not no_temporal_attn):
                if self.temporalcrossfirst:
                    # temporal cross attention
                    if self.temporal_crossattn_type == "crossattn":
                        # if temporal_context is not None:
                        temporal_context = context.repeat(h * w, 1, 1)
                        x = self.attn2_tmp(self.norm5(x), context=temporal_context, mask=mask) + x
                    elif self.temporal_crossattn_type == "selfattn":
                        x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
                    elif self.temporal_crossattn_type == "skip":
                        pass
                    else:
                        raise NotImplementedError
                    # temporal self attention
                    x = rearrange(x, "b c t h w -> (b h w) t c")
                    x = self.attn1_tmp(self.norm4(x), mask=mask) + x
                else:
                    # temporal self attention
                    x = rearrange(x, "b c t h w -> (b h w) t c")
                    x = self.attn1_tmp(self.norm4(x), mask=mask) + x
                    # temporal cross attention
                    if self.temporal_crossattn_type == "crossattn":
                        if temporal_context is not None:
                            temporal_context = context.repeat(h * w, 1, 1)
                        x = self.attn2_tmp(self.norm5(x), context=temporal_context, mask=mask) + x
                    elif self.temporal_crossattn_type == "selfattn":
                        x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
                    elif self.temporal_crossattn_type == "skip":
                        pass
                    else:
                        raise NotImplementedError
        else:
            raise NotImplementedError

        return x
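# --- Illustrative snippet (not part of this class): round-trips the two token
# --- layouts used by _st_cross_attn on a dummy 5D tensor, assuming only torch
# --- and einops. Spatial attention sees h*w tokens per frame; temporal
# --- attention sees t tokens per pixel location.
import torch
from einops import rearrange

b, c, t, h, w = 2, 8, 4, 16, 16
x = torch.randn(b, c, t, h, w)
xs = rearrange(x, "b c t h w -> (b t) (h w) c")   # spatial layout: (b*t, h*w, c)
xt = rearrange(x, "b c t h w -> (b h w) t c")     # temporal layout: (b*h*w, t, c)
assert xs.shape == (b * t, h * w, c) and xt.shape == (b * h * w, t, c)
# both layouts invert back to the original 5D tensor
assert torch.equal(rearrange(xs, "(b t) (h w) c -> b c t h w", b=b, h=h), x)
assert torch.equal(rearrange(xt, "(b h w) t c -> b c t h w", b=b, h=h, w=w), x)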
    def _st_cross_attn_parallel(
        self, x, context=None, temporal_context=None, order="sst", no_temporal_attn=None
    ):
        """order: x -> Self Attn -> Cross Attn -> attn_s
        x -> Temp Self Attn -> attn_t
        x' = x + attn_s + attn_t
        """
        if no_temporal_attn is not None:
            raise NotImplementedError
        B, C, T, H, W = x.shape

        # spatial self attention
        h = x
        h = rearrange(h, "b c t h w -> (b t) (h w) c")
        h = self.attn1(self.norm1(h)) + h
        # spatial cross attention
        # context_ = context.repeat(T, 1, 1) if context is not None else None
        if context is not None:
            context_ = []
            for i in range(context.shape[0]):
                context_.append(context[i].unsqueeze(0).repeat(T, 1, 1))
            context_ = torch.cat(context_, dim=0)
        else:
            context_ = None
        h = self.attn2(self.norm2(h), context=context_) + h
        h = rearrange(h, "(b t) (h w) c -> b c t h w", b=B, h=H)

        # temporal self attention
        h2 = x
        h2 = rearrange(h2, "b c t h w -> (b h w) t c")
        h2 = self.attn1_tmp(self.norm4(h2))  # + h2
        h2 = rearrange(h2, "(b h w) t c -> b c t h w", b=B, h=H, w=W)

        out = h + h2
        return rearrange(out, "b c t h w -> (b h w) t c")
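# --- Illustrative sketch (assumed names, not repository code): the parallel
# --- order described in the docstring computes the spatial and temporal
# --- branches from the same input and sums them, instead of chaining them as
# --- the "stst"/"sstt" orders do. A simplified version of that combination:
import torch

def parallel_st_sketch(x, attn_spatial, attn_temporal):
    # x: (b, c, t, h, w); both branches read the same x
    return x + attn_spatial(x) + attn_temporal(x)

# with zero-output branches the block starts as an identity mapping, the usual
# trick for safely bolting new layers onto a pretrained backbone
x = torch.randn(1, 4, 2, 8, 8)
out = parallel_st_sketch(x, lambda v: torch.zeros_like(v), lambda v: torch.zeros_like(v))
assert torch.equal(out, x)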
def spatial_attn_reshape(x):
    return rearrange(x, "b c t h w -> (b t) (h w) c")


def spatial_attn_reshape_back(x, b, h):
    return rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)


def temporal_attn_reshape(x):
    return rearrange(x, "b c t h w -> (b h w) t c")


def temporal_attn_reshape_back(x, b, h, w):
    return rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)


def local_spatial_temporal_attn_reshape(x, window_size):
    B, C, T, H, W = x.shape
    NH = H // window_size
    NW = W // window_size
    # x = x.view(B, C, T, NH, window_size, NW, window_size)
    # tokens = x.permute(0, 1, 2, 3, 5, 4, 6).contiguous()
    # tokens = tokens.view(-1, window_size, window_size, C)
    x = rearrange(
        x,
        "b c t (nh wh) (nw ww) -> b c t nh wh nw ww",
        nh=NH,
        nw=NW,
        wh=window_size,  # B, C, T, NH, NW, window_size, window_size
        ww=window_size,
    ).contiguous()
    # (B, NH, NW) (T, window_size, window_size) C
    x = rearrange(x, "b c t nh wh nw ww -> (b nh nw) (t wh ww) c")
    return x


def local_spatial_temporal_attn_reshape_back(x, window_size, b, h, w, t):
    B, L, C = x.shape
    NH = h // window_size
    NW = w // window_size
    x = rearrange(
        x,
        "(b nh nw) (t wh ww) c -> b c t nh wh nw ww",
        b=b,
        nh=NH,
        nw=NW,
        t=t,
        wh=window_size,
        ww=window_size,
    )
    x = rearrange(x, "b c t nh wh nw ww -> b c t (nh wh) (nw ww)")
    return x
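# --- Quick shape check for the two local-window helpers above (illustrative,
# --- not part of the original file): each (window_size x window_size) spatial
# --- patch across all t frames becomes one token sequence, and the *_back
# --- helper inverts the reshape exactly.
import torch

_b, _c, _t, _h, _w, _ws = 2, 8, 3, 8, 8, 4
_x = torch.randn(_b, _c, _t, _h, _w)
_tokens = local_spatial_temporal_attn_reshape(_x, window_size=_ws)
assert _tokens.shape == (_b * (_h // _ws) * (_w // _ws), _t * _ws * _ws, _c)
_x_back = local_spatial_temporal_attn_reshape_back(
    _tokens, window_size=_ws, b=_b, h=_h, w=_w, t=_t
)
assert torch.equal(_x_back, _x)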
class SpatialTemporalTransformer(nn.Module):
    """
    Transformer block for video-like data (5D tensor).
    First, project the input (aka embedding) with NO reshape.
    Then apply standard transformer action.
    The 5D -> 3D reshape operation will be done in the specific attention module.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        # Temporal stuff
        temporal_length=None,
        image_length=None,
        use_relative_position=True,
        img_video_joint_train=False,
        cross_attn_on_tempoal=False,
        temporal_crossattn_type=False,
        order="stst",
        temporalcrossfirst=False,
        split_stcontext=False,
        temporal_context_dim=None,
        **kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head

        self.norm = Normalize(in_channels)
        self.proj_in = nn.Conv3d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlockST(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    # cross attn
                    context_dim=context_dim,
                    # temporal attn
                    temporal_length=temporal_length,
                    image_length=image_length,
                    use_relative_position=use_relative_position,
                    img_video_joint_train=img_video_joint_train,
                    temporal_crossattn_type=temporal_crossattn_type,
                    order=order,
                    temporalcrossfirst=temporalcrossfirst,
                    split_stcontext=split_stcontext,
                    temporal_context_dim=temporal_context_dim,
                    **kwargs,
                )
                for d in range(depth)
            ]
        )

        self.proj_out = zero_module(
            nn.Conv3d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x, context=None, temporal_context=None, **kwargs):
        # note: if no context is given, cross-attention defaults to self-attention
        assert x.dim() == 5, f"x shape = {x.shape}"
        b, c, t, h, w = x.shape
        x_in = x

        x = self.norm(x)
        x = self.proj_in(x)
        for block in self.transformer_blocks:
            x = block(x, context=context, temporal_context=temporal_context, **kwargs)
        x = self.proj_out(x)
        return x + x_in
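# --- Illustrative sketch (assumes only torch; the zero_module_sketch helper is
# --- a stand-in for the repository's own zero_module): proj_in/proj_out wrap
# --- the transformer blocks in a residual branch whose output projection
# --- starts at zero, so a freshly added SpatialTemporalTransformer initially
# --- acts as an identity function on x.
import torch
import torch.nn as nn

def zero_module_sketch(module: nn.Module) -> nn.Module:
    # zero all parameters so the module's initial output is exactly zero
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

proj_out = zero_module_sketch(nn.Conv3d(64, 32, kernel_size=1))
x_in = torch.randn(1, 32, 4, 8, 8)
inner = torch.randn(1, 64, 4, 8, 8)   # stand-in for the transformer output
assert torch.equal(proj_out(inner) + x_in, x_in)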
class STAttentionBlock2(nn.Module):
    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,  # not used, only used in ResBlock
        use_new_attention_order=False,  # QKVAttention or QKVAttentionLegacy
        temporal_length=16,  # used in relative positional representation.
        image_length=8,  # used for image-video joint training.
        # whether use relative positional representation in temporal attention.
        use_relative_position=False,
        img_video_joint_train=False,
        # norm_type="groupnorm",
        attn_norm_type="group",
        use_tempoal_causal_attn=False,
    ):
        """
        version 1: guided_diffusion implemented version
        version 2: remove args input argument
        """
        super().__init__()
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.temporal_length = temporal_length
        self.image_length = image_length
        self.use_relative_position = use_relative_position
        self.img_video_joint_train = img_video_joint_train
        self.attn_norm_type = attn_norm_type
        assert self.attn_norm_type in ["group", "no_norm"]
        self.use_tempoal_causal_attn = use_tempoal_causal_attn

        if self.attn_norm_type == "group":
            self.norm_s = normalization(channels)
            self.norm_t = normalization(channels)

        self.qkv_s = conv_nd(1, channels, channels * 3, 1)
        self.qkv_t = conv_nd(1, channels, channels * 3, 1)

        if self.img_video_joint_train:
            mask = th.ones(
                [1, temporal_length + image_length, temporal_length + image_length]
            )
            mask[:, temporal_length:, :] = 0
            mask[:, :, temporal_length:] = 0
            self.register_buffer("mask", mask)
        else:
            self.mask = None

        if use_new_attention_order:
            # split qkv before split heads
            self.attention_s = QKVAttention(self.num_heads)
            self.attention_t = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention_s = QKVAttentionLegacy(self.num_heads)
            self.attention_t = QKVAttentionLegacy(self.num_heads)

        if use_relative_position:
            self.relative_position_k = RelativePosition(
                num_units=channels // self.num_heads,
                max_relative_position=temporal_length,
            )
            self.relative_position_v = RelativePosition(
                num_units=channels // self.num_heads,
                max_relative_position=temporal_length,
            )

        self.proj_out_s = zero_module(
            # conv_dim, in_channels, out_channels, kernel_size
            conv_nd(1, channels, channels, 1)
        )
        self.proj_out_t = zero_module(
            # conv_dim, in_channels, out_channels, kernel_size
            conv_nd(1, channels, channels, 1)
        )

    def forward(self, x, mask=None):
        b, c, t, h, w = x.shape

        # spatial
        out = rearrange(x, "b c t h w -> (b t) c (h w)")
        if self.attn_norm_type == "no_norm":
            qkv = self.qkv_s(out)
        else:
            qkv = self.qkv_s(self.norm_s(out))
        out = self.attention_s(qkv)
        out = self.proj_out_s(out)
        out = rearrange(out, "(b t) c (h w) -> b c t h w", b=b, h=h)
        x += out

        # temporal
        out = rearrange(x, "b c t h w -> (b h w) c t")
        if self.attn_norm_type == "no_norm":
            qkv = self.qkv_t(out)
        else:
            qkv = self.qkv_t(self.norm_t(out))

        # relative positional embedding
        if self.use_relative_position:
            len_q = qkv.size()[-1]
            len_k, len_v = len_q, len_q
            k_rp = self.relative_position_k(len_q, len_k)
            v_rp = self.relative_position_v(len_q, len_v)  # [T,T,head_dim]
            out = self.attention_t(
                qkv,
                rp=(k_rp, v_rp),
                mask=self.mask,
                use_tempoal_causal_attn=self.use_tempoal_causal_attn,
            )
        else:
            out = self.attention_t(
                qkv,
                rp=None,
                mask=self.mask,
                use_tempoal_causal_attn=self.use_tempoal_causal_attn,
            )

        out = self.proj_out_t(out)
        out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
        return x + out
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, rp=None, mask=None):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        if rp is not None or mask is not None:
            raise NotImplementedError
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
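# --- Illustrative usage (not repository code): the packed qkv layout expected
# --- by QKVAttentionLegacy. For H heads of C channels each, the input is
# --- [N, H*3*C, T]; the module reshapes to (N*H, 3*C, T), splits it into
# --- q/k/v, and returns a tensor of shape [N, H*C, T].
import torch

N, H, C, T = 2, 4, 8, 6
qkv = torch.randn(N, H * 3 * C, T)
attn = QKVAttentionLegacy(n_heads=H)
out = attn(qkv)
assert out.shape == (N, H * C, T)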
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, rp=None, mask=None, use_tempoal_causal_attn=False):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # print('qkv', qkv.size())
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        # print('bs, self.n_heads, ch, length', bs, self.n_heads, ch, length)
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        # weight: [b,t,s] with b = bs*n_heads*T

        if rp is not None:
            k_rp, v_rp = rp  # [length, length, head_dim] [8, 8, 48]
            weight2 = th.einsum(
                "bct,tsc->bst",
                (q * scale).view(bs * self.n_heads, ch, length),
                k_rp,
            )
            weight += weight2

        if use_tempoal_causal_attn:
            # weight = torch.tril(weight)
            assert mask is None, "Not implemented for merging two masks!"
            mask = torch.tril(torch.ones(weight.shape))
        else:
            if mask is not None:
                # only keep upper-left matrix; process mask
                c, t, _ = weight.shape
                if mask.shape[-1] > t:
                    mask = mask[:, :t, :t]
                elif mask.shape[-1] < t:
                    # pad ones
                    mask_ = th.zeros([c, t, t]).to(mask.device)
                    t_ = mask.shape[-1]
                    mask_[:, :t_, :t_] = mask
                    mask = mask_
                else:
                    assert (
                        weight.shape[-1] == mask.shape[-1]
                    ), f"weight={weight.shape}, mask={mask.shape}"

        if mask is not None:
            INF = -1e8  # float('-inf')
            weight = weight.float().masked_fill(mask == 0, INF)

        weight = F.softmax(weight.float(), dim=-1).type(
            weight.dtype
        )  # [256, 8, 8] [b, t, t] b=bs*n_heads*h*w, t=nframes
        # weight = F.softmax(weight, dim=-1)  # [256, 8, 8] [b, t, t] b=bs*n_heads*h*w, t=nframes

        # [256, 48, 8] [b, head_dim, t]
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        if rp is not None:
            a2 = th.einsum("bts,tsc->btc", weight, v_rp).transpose(1, 2)  # btc -> bct
            a += a2

        return a.reshape(bs, -1, length)
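# --- Quick numerical check (illustrative, not repository code): scaling q and
# --- k each by 1/sqrt(sqrt(ch)) before the einsum is algebraically identical
# --- to the usual QK^T / sqrt(ch), but keeps intermediate magnitudes smaller,
# --- which is friendlier to fp16, as the comment in the module notes.
import math
import torch

ch, length = 32, 8
q = torch.randn(1, ch, length)
k = torch.randn(1, ch, length)
scale = 1 / math.sqrt(math.sqrt(ch))
w1 = torch.einsum("bct,bcs->bts", q * scale, k * scale)
w2 = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)
assert torch.allclose(w1, w2, atol=1e-5)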
NVComposer/core/modules/encoders/__init__.py (new, empty file)
NVComposer/core/modules/encoders/adapter.py
import
torch
import
torch.nn
as
nn
from
collections
import
OrderedDict
from
extralibs.cond_api
import
ExtraCondition
from
core.modules.x_transformer
import
FixedPositionalEmbedding
from
core.basics
import
zero_module
,
conv_nd
,
avg_pool_nd
class
Downsample
(
nn
.
Module
):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
dims
=
dims
stride
=
2
if
dims
!=
3
else
(
1
,
2
,
2
)
if
use_conv
:
self
.
op
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
padding
,
)
else
:
assert
self
.
channels
==
self
.
out_channels
self
.
op
=
avg_pool_nd
(
dims
,
kernel_size
=
stride
,
stride
=
stride
)
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
return
self
.
op
(
x
)
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_c
,
out_c
,
down
,
ksize
=
3
,
sk
=
False
,
use_conv
=
True
):
super
().
__init__
()
ps
=
ksize
//
2
if
in_c
!=
out_c
or
sk
==
False
:
self
.
in_conv
=
nn
.
Conv2d
(
in_c
,
out_c
,
ksize
,
1
,
ps
)
else
:
self
.
in_conv
=
None
self
.
block1
=
nn
.
Conv2d
(
out_c
,
out_c
,
3
,
1
,
1
)
self
.
act
=
nn
.
ReLU
()
self
.
block2
=
nn
.
Conv2d
(
out_c
,
out_c
,
ksize
,
1
,
ps
)
if
sk
==
False
:
self
.
skep
=
nn
.
Conv2d
(
in_c
,
out_c
,
ksize
,
1
,
ps
)
else
:
self
.
skep
=
None
self
.
down
=
down
if
self
.
down
==
True
:
self
.
down_opt
=
Downsample
(
in_c
,
use_conv
=
use_conv
)
def
forward
(
self
,
x
):
if
self
.
down
==
True
:
x
=
self
.
down_opt
(
x
)
if
self
.
in_conv
is
not
None
:
x
=
self
.
in_conv
(
x
)
h
=
self
.
block1
(
x
)
h
=
self
.
act
(
h
)
h
=
self
.
block2
(
h
)
if
self
.
skep
is
not
None
:
return
h
+
self
.
skep
(
x
)
else
:
return
h
+
x
class
Adapter
(
nn
.
Module
):
def
__init__
(
self
,
channels
=
[
320
,
640
,
1280
,
1280
],
nums_rb
=
3
,
cin
=
64
,
ksize
=
3
,
sk
=
True
,
use_conv
=
True
,
stage_downscale
=
True
,
use_identity
=
False
,
):
super
(
Adapter
,
self
).
__init__
()
if
use_identity
:
self
.
inlayer
=
nn
.
Identity
()
else
:
self
.
inlayer
=
nn
.
PixelUnshuffle
(
8
)
self
.
channels
=
channels
self
.
nums_rb
=
nums_rb
self
.
body
=
[]
for
i
in
range
(
len
(
channels
)):
for
j
in
range
(
nums_rb
):
if
(
i
!=
0
)
and
(
j
==
0
):
self
.
body
.
append
(
ResnetBlock
(
channels
[
i
-
1
],
channels
[
i
],
down
=
stage_downscale
,
ksize
=
ksize
,
sk
=
sk
,
use_conv
=
use_conv
,
)
)
else
:
self
.
body
.
append
(
ResnetBlock
(
channels
[
i
],
channels
[
i
],
down
=
False
,
ksize
=
ksize
,
sk
=
sk
,
use_conv
=
use_conv
,
)
)
self
.
body
=
nn
.
ModuleList
(
self
.
body
)
self
.
conv_in
=
nn
.
Conv2d
(
cin
,
channels
[
0
],
3
,
1
,
1
)
def
forward
(
self
,
x
):
# unshuffle
x
=
self
.
inlayer
(
x
)
# extract features
features
=
[]
x
=
self
.
conv_in
(
x
)
for
i
in
range
(
len
(
self
.
channels
)):
for
j
in
range
(
self
.
nums_rb
):
idx
=
i
*
self
.
nums_rb
+
j
x
=
self
.
body
[
idx
](
x
)
features
.
append
(
x
)
return
features
class
PositionNet
(
nn
.
Module
):
def
__init__
(
self
,
input_size
=
(
40
,
64
),
cin
=
4
,
dim
=
512
,
out_dim
=
1024
):
super
().
__init__
()
self
.
input_size
=
input_size
self
.
out_dim
=
out_dim
self
.
down_factor
=
8
# determined by the convnext backbone
feature_dim
=
dim
self
.
backbone
=
Adapter
(
channels
=
[
64
,
128
,
256
,
feature_dim
],
nums_rb
=
2
,
cin
=
cin
,
stage_downscale
=
True
,
use_identity
=
True
,
)
self
.
num_tokens
=
(
self
.
input_size
[
0
]
//
self
.
down_factor
)
*
(
self
.
input_size
[
1
]
//
self
.
down_factor
)
self
.
pos_embedding
=
nn
.
Parameter
(
torch
.
empty
(
1
,
self
.
num_tokens
,
feature_dim
).
normal_
(
std
=
0.02
)
)
# from BERT
self
.
linears
=
nn
.
Sequential
(
nn
.
Linear
(
feature_dim
,
512
),
nn
.
SiLU
(),
nn
.
Linear
(
512
,
512
),
nn
.
SiLU
(),
nn
.
Linear
(
512
,
out_dim
),
)
# self.null_feature = torch.nn.Parameter(torch.zeros([feature_dim]))
def
forward
(
self
,
x
,
mask
=
None
):
B
=
x
.
shape
[
0
]
# token from edge map
# x = torch.nn.functional.interpolate(x, self.input_size)
feature
=
self
.
backbone
(
x
)[
-
1
]
objs
=
feature
.
reshape
(
B
,
-
1
,
self
.
num_tokens
)
objs
=
objs
.
permute
(
0
,
2
,
1
)
# N*Num_tokens*dim
"""
# expand null token
null_objs = self.null_feature.view(1,1,-1)
null_objs = null_objs.repeat(B,self.num_tokens,1)
# mask replacing
mask = mask.view(-1,1,1)
objs = objs*mask + null_objs*(1-mask)
"""
# add pos
objs
=
objs
+
self
.
pos_embedding
# fuse them
objs
=
self
.
linears
(
objs
)
assert
objs
.
shape
==
torch
.
Size
([
B
,
self
.
num_tokens
,
self
.
out_dim
])
return
objs
class
PositionNet2
(
nn
.
Module
):
def
__init__
(
self
,
input_size
=
(
40
,
64
),
cin
=
4
,
dim
=
320
,
out_dim
=
1024
):
super
().
__init__
()
self
.
input_size
=
input_size
self
.
out_dim
=
out_dim
self
.
down_factor
=
8
# determined by the convnext backbone
self
.
dim
=
dim
self
.
backbone
=
Adapter
(
channels
=
[
dim
,
dim
,
dim
,
dim
],
nums_rb
=
2
,
cin
=
cin
,
stage_downscale
=
True
,
use_identity
=
True
,
)
self
.
pos_embedding
=
FixedPositionalEmbedding
(
dim
=
self
.
dim
)
self
.
linears
=
nn
.
Sequential
(
nn
.
Linear
(
dim
,
512
),
nn
.
SiLU
(),
nn
.
Linear
(
512
,
512
),
nn
.
SiLU
(),
nn
.
Linear
(
512
,
out_dim
),
)
def
forward
(
self
,
x
,
mask
=
None
):
B
=
x
.
shape
[
0
]
features
=
self
.
backbone
(
x
)
token_lists
=
[]
for
feature
in
features
:
objs
=
feature
.
reshape
(
B
,
self
.
dim
,
-
1
)
objs
=
objs
.
permute
(
0
,
2
,
1
)
# N*Num_tokens*dim
# add pos
objs
=
objs
+
self
.
pos_embedding
(
objs
)
# fuse them
objs
=
self
.
linears
(
objs
)
token_lists
.
append
(
objs
)
return
token_lists
class
LayerNorm
(
nn
.
LayerNorm
):
"""Subclass torch's LayerNorm to handle fp16."""
def
forward
(
self
,
x
:
torch
.
Tensor
):
orig_type
=
x
.
dtype
ret
=
super
().
forward
(
x
.
type
(
torch
.
float32
))
return
ret
.
type
(
orig_type
)
class
QuickGELU
(
nn
.
Module
):
def
forward
(
self
,
x
:
torch
.
Tensor
):
return
x
*
torch
.
sigmoid
(
1.702
*
x
)
class
ResidualAttentionBlock
(
nn
.
Module
):
def
__init__
(
self
,
d_model
:
int
,
n_head
:
int
,
attn_mask
:
torch
.
Tensor
=
None
):
super
().
__init__
()
self
.
attn
=
nn
.
MultiheadAttention
(
d_model
,
n_head
)
self
.
ln_1
=
LayerNorm
(
d_model
)
self
.
mlp
=
nn
.
Sequential
(
OrderedDict
(
[
(
"c_fc"
,
nn
.
Linear
(
d_model
,
d_model
*
4
)),
(
"gelu"
,
QuickGELU
()),
(
"c_proj"
,
nn
.
Linear
(
d_model
*
4
,
d_model
)),
]
)
)
self
.
ln_2
=
LayerNorm
(
d_model
)
self
.
attn_mask
=
attn_mask
def
attention
(
self
,
x
:
torch
.
Tensor
):
self
.
attn_mask
=
(
self
.
attn_mask
.
to
(
dtype
=
x
.
dtype
,
device
=
x
.
device
)
if
self
.
attn_mask
is
not
None
else
None
)
return
self
.
attn
(
x
,
x
,
x
,
need_weights
=
False
,
attn_mask
=
self
.
attn_mask
)[
0
]
def
forward
(
self
,
x
:
torch
.
Tensor
):
x
=
x
+
self
.
attention
(
self
.
ln_1
(
x
))
x
=
x
+
self
.
mlp
(
self
.
ln_2
(
x
))
return
x
class
StyleAdapter
(
nn
.
Module
):
def
__init__
(
self
,
width
=
1024
,
context_dim
=
768
,
num_head
=
8
,
n_layes
=
3
,
num_token
=
4
):
super
().
__init__
()
scale
=
width
**-
0.5
self
.
transformer_layes
=
nn
.
Sequential
(
*
[
ResidualAttentionBlock
(
width
,
num_head
)
for
_
in
range
(
n_layes
)]
)
self
.
num_token
=
num_token
self
.
style_embedding
=
nn
.
Parameter
(
torch
.
randn
(
1
,
num_token
,
width
)
*
scale
)
self
.
ln_post
=
LayerNorm
(
width
)
self
.
ln_pre
=
LayerNorm
(
width
)
self
.
proj
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
width
,
context_dim
))
def
forward
(
self
,
x
):
# x shape [N, HW+1, C]
style_embedding
=
self
.
style_embedding
+
torch
.
zeros
(
(
x
.
shape
[
0
],
self
.
num_token
,
self
.
style_embedding
.
shape
[
-
1
]),
device
=
x
.
device
,
)
x
=
torch
.
cat
([
x
,
style_embedding
],
dim
=
1
)
x
=
self
.
ln_pre
(
x
)
x
=
x
.
permute
(
1
,
0
,
2
)
# NLD -> LND
x
=
self
.
transformer_layes
(
x
)
x
=
x
.
permute
(
1
,
0
,
2
)
# LND -> NLD
x
=
self
.
ln_post
(
x
[:,
-
self
.
num_token
:,
:])
x
=
x
@
self
.
proj
return
x
class
ResnetBlock_light
(
nn
.
Module
):
def
__init__
(
self
,
in_c
):
super
().
__init__
()
self
.
block1
=
nn
.
Conv2d
(
in_c
,
in_c
,
3
,
1
,
1
)
self
.
act
=
nn
.
ReLU
()
self
.
block2
=
nn
.
Conv2d
(
in_c
,
in_c
,
3
,
1
,
1
)
def
forward
(
self
,
x
):
h
=
self
.
block1
(
x
)
h
=
self
.
act
(
h
)
h
=
self
.
block2
(
h
)
return
h
+
x
class
extractor
(
nn
.
Module
):
def
__init__
(
self
,
in_c
,
inter_c
,
out_c
,
nums_rb
,
down
=
False
):
super
().
__init__
()
self
.
in_conv
=
nn
.
Conv2d
(
in_c
,
inter_c
,
1
,
1
,
0
)
self
.
body
=
[]
for
_
in
range
(
nums_rb
):
self
.
body
.
append
(
ResnetBlock_light
(
inter_c
))
self
.
body
=
nn
.
Sequential
(
*
self
.
body
)
self
.
out_conv
=
nn
.
Conv2d
(
inter_c
,
out_c
,
1
,
1
,
0
)
self
.
down
=
down
if
self
.
down
==
True
:
self
.
down_opt
=
Downsample
(
in_c
,
use_conv
=
False
)
def
forward
(
self
,
x
):
if
self
.
down
==
True
:
x
=
self
.
down_opt
(
x
)
x
=
self
.
in_conv
(
x
)
x
=
self
.
body
(
x
)
x
=
self
.
out_conv
(
x
)
return
x
class
Adapter_light
(
nn
.
Module
):
def
__init__
(
self
,
channels
=
[
320
,
640
,
1280
,
1280
],
nums_rb
=
3
,
cin
=
64
):
super
(
Adapter_light
,
self
).
__init__
()
self
.
unshuffle
=
nn
.
PixelUnshuffle
(
8
)
self
.
channels
=
channels
self
.
nums_rb
=
nums_rb
self
.
body
=
[]
for
i
in
range
(
len
(
channels
)):
if
i
==
0
:
self
.
body
.
append
(
extractor
(
in_c
=
cin
,
inter_c
=
channels
[
i
]
//
4
,
out_c
=
channels
[
i
],
nums_rb
=
nums_rb
,
down
=
False
,
)
)
else
:
self
.
body
.
append
(
extractor
(
in_c
=
channels
[
i
-
1
],
inter_c
=
channels
[
i
]
//
4
,
out_c
=
channels
[
i
],
nums_rb
=
nums_rb
,
down
=
True
,
)
)
self
.
body
=
nn
.
ModuleList
(
self
.
body
)
def
forward
(
self
,
x
):
# unshuffle
x
=
self
.
unshuffle
(
x
)
# extract features
features
=
[]
for
i
in
range
(
len
(
self
.
channels
)):
x
=
self
.
body
[
i
](
x
)
features
.
append
(
x
)
return
features
class
CoAdapterFuser
(
nn
.
Module
):
def
__init__
(
self
,
unet_channels
=
[
320
,
640
,
1280
,
1280
],
width
=
768
,
num_head
=
8
,
n_layes
=
3
):
super
(
CoAdapterFuser
,
self
).
__init__
()
scale
=
width
**
0.5
self
.
task_embedding
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
16
,
width
))
self
.
positional_embedding
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
len
(
unet_channels
),
width
)
)
self
.
spatial_feat_mapping
=
nn
.
ModuleList
()
for
ch
in
unet_channels
:
self
.
spatial_feat_mapping
.
append
(
nn
.
Sequential
(
nn
.
SiLU
(),
nn
.
Linear
(
ch
,
width
),
)
)
self
.
transformer_layes
=
nn
.
Sequential
(
*
[
ResidualAttentionBlock
(
width
,
num_head
)
for
_
in
range
(
n_layes
)]
)
self
.
ln_post
=
LayerNorm
(
width
)
self
.
ln_pre
=
LayerNorm
(
width
)
self
.
spatial_ch_projs
=
nn
.
ModuleList
()
for
ch
in
unet_channels
:
self
.
spatial_ch_projs
.
append
(
zero_module
(
nn
.
Linear
(
width
,
ch
)))
self
.
seq_proj
=
nn
.
Parameter
(
torch
.
zeros
(
width
,
width
))
def
forward
(
self
,
features
):
if
len
(
features
)
==
0
:
return
None
,
None
inputs
=
[]
for
cond_name
in
features
.
keys
():
task_idx
=
getattr
(
ExtraCondition
,
cond_name
).
value
if
not
isinstance
(
features
[
cond_name
],
list
):
inputs
.
append
(
features
[
cond_name
]
+
self
.
task_embedding
[
task_idx
])
continue
feat_seq
=
[]
for
idx
,
feature_map
in
enumerate
(
features
[
cond_name
]):
feature_vec
=
torch
.
mean
(
feature_map
,
dim
=
(
2
,
3
))
feature_vec
=
self
.
spatial_feat_mapping
[
idx
](
feature_vec
)
feat_seq
.
append
(
feature_vec
)
feat_seq
=
torch
.
stack
(
feat_seq
,
dim
=
1
)
# Nx4xC
feat_seq
=
feat_seq
+
self
.
task_embedding
[
task_idx
]
feat_seq
=
feat_seq
+
self
.
positional_embedding
inputs
.
append
(
feat_seq
)
x
=
torch
.
cat
(
inputs
,
dim
=
1
)
# NxLxC
x
=
self
.
ln_pre
(
x
)
x
=
x
.
permute
(
1
,
0
,
2
)
# NLD -> LND
x
=
self
.
transformer_layes
(
x
)
x
=
x
.
permute
(
1
,
0
,
2
)
# LND -> NLD
x
=
self
.
ln_post
(
x
)
ret_feat_map
=
None
ret_feat_seq
=
None
cur_seq_idx
=
0
for
cond_name
in
features
.
keys
():
if
not
isinstance
(
features
[
cond_name
],
list
):
length
=
features
[
cond_name
].
size
(
1
)
transformed_feature
=
features
[
cond_name
]
*
(
(
x
[:,
cur_seq_idx
:
cur_seq_idx
+
length
]
@
self
.
seq_proj
)
+
1
)
if
ret_feat_seq
is
None
:
ret_feat_seq
=
transformed_feature
else
:
ret_feat_seq
=
torch
.
cat
([
ret_feat_seq
,
transformed_feature
],
dim
=
1
)
cur_seq_idx
+=
length
continue
length
=
len
(
features
[
cond_name
])
transformed_feature_list
=
[]
for
idx
in
range
(
length
):
alpha
=
self
.
spatial_ch_projs
[
idx
](
x
[:,
cur_seq_idx
+
idx
])
alpha
=
alpha
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
+
1
transformed_feature_list
.
append
(
features
[
cond_name
][
idx
]
*
alpha
)
if
ret_feat_map
is
None
:
ret_feat_map
=
transformed_feature_list
else
:
ret_feat_map
=
list
(
map
(
lambda
x
,
y
:
x
+
y
,
ret_feat_map
,
transformed_feature_list
)
)
cur_seq_idx
+=
length
assert
cur_seq_idx
==
x
.
size
(
1
)
return
ret_feat_map
,
ret_feat_seq
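# --- Illustrative sketch (assumed shapes; TinyAdapterSketch is hypothetical,
# --- not repository code): the Adapter above turns a conditioning image into
# --- multi-scale feature maps by pixel-unshuffling the input and running
# --- convolutional stages with a downscale between them. A minimal stand-in
# --- with a similar output contract, for intuition about the shapes:
import torch
import torch.nn as nn

class TinyAdapterSketch(nn.Module):
    def __init__(self, cin=3 * 64, channels=(320, 640, 1280, 1280)):
        super().__init__()
        self.unshuffle = nn.PixelUnshuffle(8)           # (B,3,H,W) -> (B,192,H/8,W/8)
        self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
        ins = (channels[0],) + channels[:-1]
        self.stages = nn.ModuleList(
            [nn.Conv2d(ci, co, 3, stride=2 if i > 0 else 1, padding=1)
             for i, (ci, co) in enumerate(zip(ins, channels))]
        )

    def forward(self, x):
        feats = []
        x = self.conv_in(self.unshuffle(x))
        for stage in self.stages:
            x = stage(x)
            feats.append(x)                              # one map per stage
        return feats

feats = TinyAdapterSketch()(torch.randn(1, 3, 256, 256))
print([tuple(f.shape) for f in feats])
# [(1, 320, 32, 32), (1, 640, 16, 16), (1, 1280, 8, 8), (1, 1280, 4, 4)]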
NVComposer/core/modules/encoders/condition.py
import
torch
import
torch.nn
as
nn
import
kornia
from
torch.utils.checkpoint
import
checkpoint
from
transformers
import
T5Tokenizer
,
T5EncoderModel
,
CLIPTokenizer
,
CLIPTextModel
import
open_clip
from
core.common
import
autocast
from
utils.utils
import
count_params
class
AbstractEncoder
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
def
encode
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
class
IdentityEncoder
(
AbstractEncoder
):
def
encode
(
self
,
x
):
return
x
class
ClassEmbedder
(
nn
.
Module
):
def
__init__
(
self
,
embed_dim
,
n_classes
=
1000
,
key
=
"class"
,
ucg_rate
=
0.1
):
super
().
__init__
()
self
.
key
=
key
self
.
embedding
=
nn
.
Embedding
(
n_classes
,
embed_dim
)
self
.
n_classes
=
n_classes
self
.
ucg_rate
=
ucg_rate
def
forward
(
self
,
batch
,
key
=
None
,
disable_dropout
=
False
):
if
key
is
None
:
key
=
self
.
key
# this is for use in crossattn
c
=
batch
[
key
][:,
None
]
if
self
.
ucg_rate
>
0.0
and
not
disable_dropout
:
mask
=
1.0
-
torch
.
bernoulli
(
torch
.
ones_like
(
c
)
*
self
.
ucg_rate
)
c
=
mask
*
c
+
(
1
-
mask
)
*
torch
.
ones_like
(
c
)
*
(
self
.
n_classes
-
1
)
c
=
c
.
long
()
c
=
self
.
embedding
(
c
)
return
c
def
get_unconditional_conditioning
(
self
,
bs
,
device
=
"cuda"
):
# 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
uc_class
=
self
.
n_classes
-
1
uc
=
torch
.
ones
((
bs
,),
device
=
device
)
*
uc_class
uc
=
{
self
.
key
:
uc
}
return
uc
def
disabled_train
(
self
,
mode
=
True
):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return
self
class
FrozenT5Embedder
(
AbstractEncoder
):
"""Uses the T5 transformer encoder for text"""
def
__init__
(
self
,
version
=
"google/t5-v1_1-large"
,
device
=
"cuda"
,
max_length
=
77
,
freeze
=
True
):
super
().
__init__
()
self
.
tokenizer
=
T5Tokenizer
.
from_pretrained
(
version
)
self
.
transformer
=
T5EncoderModel
.
from_pretrained
(
version
)
self
.
device
=
device
self
.
max_length
=
max_length
if
freeze
:
self
.
freeze
()
def
freeze
(
self
):
self
.
transformer
=
self
.
transformer
.
eval
()
# self.train = disabled_train
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
def
forward
(
self
,
text
):
batch_encoding
=
self
.
tokenizer
(
text
,
truncation
=
True
,
max_length
=
self
.
max_length
,
return_length
=
True
,
return_overflowing_tokens
=
False
,
padding
=
"max_length"
,
return_tensors
=
"pt"
,
)
tokens
=
batch_encoding
[
"input_ids"
].
to
(
self
.
device
)
outputs
=
self
.
transformer
(
input_ids
=
tokens
)
z
=
outputs
.
last_hidden_state
return
z
def
encode
(
self
,
text
):
return
self
(
text
)
class
FrozenCLIPEmbedder
(
AbstractEncoder
):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS
=
[
"last"
,
"pooled"
,
"hidden"
]
def
__init__
(
self
,
version
=
"openai/clip-vit-large-patch14"
,
device
=
"cuda"
,
max_length
=
77
,
freeze
=
True
,
layer
=
"last"
,
layer_idx
=
None
,
):
# clip-vit-base-patch32
super
().
__init__
()
assert
layer
in
self
.
LAYERS
self
.
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
version
)
self
.
transformer
=
CLIPTextModel
.
from_pretrained
(
version
)
self
.
device
=
device
self
.
max_length
=
max_length
if
freeze
:
self
.
freeze
()
self
.
layer
=
layer
self
.
layer_idx
=
layer_idx
if
layer
==
"hidden"
:
assert
layer_idx
is
not
None
assert
0
<=
abs
(
layer_idx
)
<=
12
def
freeze
(
self
):
self
.
transformer
=
self
.
transformer
.
eval
()
# self.train = disabled_train
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
def
forward
(
self
,
text
):
batch_encoding
=
self
.
tokenizer
(
text
,
truncation
=
True
,
max_length
=
self
.
max_length
,
return_length
=
True
,
return_overflowing_tokens
=
False
,
padding
=
"max_length"
,
return_tensors
=
"pt"
,
)
tokens
=
batch_encoding
[
"input_ids"
].
to
(
self
.
device
)
outputs
=
self
.
transformer
(
input_ids
=
tokens
,
output_hidden_states
=
self
.
layer
==
"hidden"
)
if
self
.
layer
==
"last"
:
z
=
outputs
.
last_hidden_state
elif
self
.
layer
==
"pooled"
:
z
=
outputs
.
pooler_output
[:,
None
,
:]
else
:
z
=
outputs
.
hidden_states
[
self
.
layer_idx
]
return
z
def
encode
(
self
,
text
):
return
self
(
text
)
class
ClipImageEmbedder
(
nn
.
Module
):
def
__init__
(
self
,
model
,
jit
=
False
,
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
,
antialias
=
True
,
ucg_rate
=
0.0
,
):
super
().
__init__
()
from
clip
import
load
as
load_clip
self
.
model
,
_
=
load_clip
(
name
=
model
,
device
=
device
,
jit
=
jit
)
self
.
antialias
=
antialias
self
.
register_buffer
(
"mean"
,
torch
.
Tensor
([
0.48145466
,
0.4578275
,
0.40821073
]),
persistent
=
False
)
self
.
register_buffer
(
"std"
,
torch
.
Tensor
([
0.26862954
,
0.26130258
,
0.27577711
]),
persistent
=
False
)
self
.
ucg_rate
=
ucg_rate
def
preprocess
(
self
,
x
):
# normalize to [0,1]
x
=
kornia
.
geometry
.
resize
(
x
,
(
224
,
224
),
interpolation
=
"bicubic"
,
align_corners
=
True
,
antialias
=
self
.
antialias
,
)
x
=
(
x
+
1.0
)
/
2.0
# re-normalize according to clip
x
=
kornia
.
enhance
.
normalize
(
x
,
self
.
mean
,
self
.
std
)
return
x
def
forward
(
self
,
x
,
no_dropout
=
False
):
# x is assumed to be in range [-1,1]
out
=
self
.
model
.
encode_image
(
self
.
preprocess
(
x
))
out
=
out
.
to
(
x
.
dtype
)
if
self
.
ucg_rate
>
0.0
and
not
no_dropout
:
out
=
(
torch
.
bernoulli
(
(
1.0
-
self
.
ucg_rate
)
*
torch
.
ones
(
out
.
shape
[
0
],
device
=
out
.
device
)
)[:,
None
]
*
out
)
return
out
class
FrozenOpenCLIPEmbedder
(
AbstractEncoder
):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS
=
[
# "pooled",
"last"
,
"penultimate"
,
]
def
__init__
(
self
,
arch
=
"ViT-H-14"
,
version
=
None
,
device
=
"cuda"
,
max_length
=
77
,
freeze
=
True
,
layer
=
"last"
,
):
super
().
__init__
()
assert
layer
in
self
.
LAYERS
model
,
_
,
_
=
open_clip
.
create_model_and_transforms
(
arch
,
device
=
torch
.
device
(
"cpu"
),
pretrained
=
version
)
del
model
.
visual
self
.
model
=
model
self
.
device
=
device
self
.
max_length
=
max_length
if
freeze
:
self
.
freeze
()
self
.
layer
=
layer
if
self
.
layer
==
"last"
:
self
.
layer_idx
=
0
elif
self
.
layer
==
"penultimate"
:
self
.
layer_idx
=
1
else
:
raise
NotImplementedError
()
def
freeze
(
self
):
self
.
model
=
self
.
model
.
eval
()
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
def
forward
(
self
,
text
):
tokens
=
open_clip
.
tokenize
(
text
)
z
=
self
.
encode_with_transformer
(
tokens
.
to
(
self
.
device
))
return
z
def
encode_with_transformer
(
self
,
text
):
x
=
self
.
model
.
token_embedding
(
text
)
# [batch_size, n_ctx, d_model]
x
=
x
+
self
.
model
.
positional_embedding
x
=
x
.
permute
(
1
,
0
,
2
)
# NLD -> LND
x
=
self
.
text_transformer_forward
(
x
,
attn_mask
=
self
.
model
.
attn_mask
)
x
=
x
.
permute
(
1
,
0
,
2
)
# LND -> NLD
x
=
self
.
model
.
ln_final
(
x
)
return
x
def
text_transformer_forward
(
self
,
x
:
torch
.
Tensor
,
attn_mask
=
None
):
for
i
,
r
in
enumerate
(
self
.
model
.
transformer
.
resblocks
):
if
i
==
len
(
self
.
model
.
transformer
.
resblocks
)
-
self
.
layer_idx
:
break
if
(
self
.
model
.
transformer
.
grad_checkpointing
and
not
torch
.
jit
.
is_scripting
()
):
x
=
checkpoint
(
r
,
x
,
attn_mask
)
else
:
x
=
r
(
x
,
attn_mask
=
attn_mask
)
return
x
def
encode
(
self
,
text
):
return
self
(
text
)
class
FrozenOpenCLIPImageEmbedder
(
AbstractEncoder
):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def
__init__
(
self
,
arch
=
"ViT-H-14"
,
version
=
None
,
device
=
"cuda"
,
max_length
=
77
,
freeze
=
True
,
layer
=
"pooled"
,
antialias
=
True
,
ucg_rate
=
0.0
,
):
super
().
__init__
()
model
,
_
,
_
=
open_clip
.
create_model_and_transforms
(
arch
,
device
=
torch
.
device
(
"cpu"
),
pretrained
=
version
)
del
model
.
transformer
self
.
model
=
model
self
.
device
=
device
self
.
max_length
=
max_length
if
freeze
:
self
.
freeze
()
self
.
layer
=
layer
if
self
.
layer
==
"penultimate"
:
raise
NotImplementedError
()
self
.
layer_idx
=
1
self
.
antialias
=
antialias
self
.
register_buffer
(
"mean"
,
torch
.
Tensor
([
0.48145466
,
0.4578275
,
0.40821073
]),
persistent
=
False
)
self
.
register_buffer
(
"std"
,
torch
.
Tensor
([
0.26862954
,
0.26130258
,
0.27577711
]),
persistent
=
False
)
self
.
ucg_rate
=
ucg_rate
def
preprocess
(
self
,
x
):
# normalize to [0,1]
x
=
kornia
.
geometry
.
resize
(
x
,
(
224
,
224
),
interpolation
=
"bicubic"
,
align_corners
=
True
,
antialias
=
self
.
antialias
,
)
x
=
(
x
+
1.0
)
/
2.0
# renormalize according to clip
x
=
kornia
.
enhance
.
normalize
(
x
,
self
.
mean
,
self
.
std
)
return
x
def
freeze
(
self
):
self
.
model
=
self
.
model
.
eval
()
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
@
autocast
def
forward
(
self
,
image
,
no_dropout
=
False
):
z
=
self
.
encode_with_vision_transformer
(
image
)
if
self
.
ucg_rate
>
0.0
and
not
no_dropout
:
z
=
(
torch
.
bernoulli
(
(
1.0
-
self
.
ucg_rate
)
*
torch
.
ones
(
z
.
shape
[
0
],
device
=
z
.
device
)
)[:,
None
]
*
z
)
return
z
def
encode_with_vision_transformer
(
self
,
img
):
img
=
self
.
preprocess
(
img
)
x
=
self
.
model
.
visual
(
img
)
return
x
def
encode
(
self
,
text
):
return
self
(
text
)
class
FrozenOpenCLIPImageEmbedderV2
(
AbstractEncoder
):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def
__init__
(
self
,
arch
=
"ViT-H-14"
,
version
=
None
,
device
=
"cuda"
,
freeze
=
True
,
layer
=
"pooled"
,
antialias
=
True
,
):
super
().
__init__
()
model
,
_
,
_
=
open_clip
.
create_model_and_transforms
(
arch
,
device
=
torch
.
device
(
"cpu"
),
pretrained
=
version
,
)
del
model
.
transformer
self
.
model
=
model
self
.
device
=
device
if
freeze
:
self
.
freeze
()
self
.
layer
=
layer
if
self
.
layer
==
"penultimate"
:
raise
NotImplementedError
()
self
.
layer_idx
=
1
self
.
antialias
=
antialias
self
.
register_buffer
(
"mean"
,
torch
.
Tensor
([
0.48145466
,
0.4578275
,
0.40821073
]),
persistent
=
False
)
self
.
register_buffer
(
"std"
,
torch
.
Tensor
([
0.26862954
,
0.26130258
,
0.27577711
]),
persistent
=
False
)
def
preprocess
(
self
,
x
):
# normalize to [0,1]
x
=
kornia
.
geometry
.
resize
(
x
,
(
224
,
224
),
interpolation
=
"bicubic"
,
align_corners
=
True
,
antialias
=
self
.
antialias
,
)
x
=
(
x
+
1.0
)
/
2.0
# renormalize according to clip
x
=
kornia
.
enhance
.
normalize
(
x
,
self
.
mean
,
self
.
std
)
return
x
def
freeze
(
self
):
self
.
model
=
self
.
model
.
eval
()
for
param
in
self
.
model
.
parameters
():
param
.
requires_grad
=
False
def
forward
(
self
,
image
,
no_dropout
=
False
):
# image: b c h w
z
=
self
.
encode_with_vision_transformer
(
image
)
return
z
def
encode_with_vision_transformer
(
self
,
x
):
x
=
self
.
preprocess
(
x
)
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
if
self
.
model
.
visual
.
input_patchnorm
:
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
x
=
x
.
reshape
(
x
.
shape
[
0
],
x
.
shape
[
1
],
self
.
model
.
visual
.
grid_size
[
0
],
self
.
model
.
visual
.
patch_size
[
0
],
self
.
model
.
visual
.
grid_size
[
1
],
self
.
model
.
visual
.
patch_size
[
1
],
)
x
=
x
.
permute
(
0
,
2
,
4
,
1
,
3
,
5
)
x
=
x
.
reshape
(
x
.
shape
[
0
],
self
.
model
.
visual
.
grid_size
[
0
]
*
self
.
model
.
visual
.
grid_size
[
1
],
-
1
,
)
x
=
self
.
model
.
visual
.
patchnorm_pre_ln
(
x
)
x
=
self
.
model
.
visual
.
conv1
(
x
)
else
:
x
=
self
.
model
.
visual
.
conv1
(
x
)
# shape = [*, width, grid, grid]
# shape = [*, width, grid ** 2]
x
=
x
.
reshape
(
x
.
shape
[
0
],
x
.
shape
[
1
],
-
1
)
x
=
x
.
permute
(
0
,
2
,
1
)
# shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x
=
torch
.
cat
(
[
self
.
model
.
visual
.
class_embedding
.
to
(
x
.
dtype
)
+
torch
.
zeros
(
x
.
shape
[
0
],
1
,
x
.
shape
[
-
1
],
dtype
=
x
.
dtype
,
device
=
x
.
device
),
x
,
],
dim
=
1
,
)
# shape = [*, grid ** 2 + 1, width]
x
=
x
+
self
.
model
.
visual
.
positional_embedding
.
to
(
x
.
dtype
)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x
=
self
.
model
.
visual
.
patch_dropout
(
x
)
x
=
self
.
model
.
visual
.
ln_pre
(
x
)
x
=
x
.
permute
(
1
,
0
,
2
)
# NLD -> LND
x
=
self
.
model
.
visual
.
transformer
(
x
)
x
=
x
.
permute
(
1
,
0
,
2
)
# LND -> NLD
return
x
class
FrozenCLIPT5Encoder
(
AbstractEncoder
):
def
__init__
(
self
,
clip_version
=
"openai/clip-vit-large-patch14"
,
t5_version
=
"google/t5-v1_1-xl"
,
device
=
"cuda"
,
clip_max_length
=
77
,
t5_max_length
=
77
,
):
super
().
__init__
()
self
.
clip_encoder
=
FrozenCLIPEmbedder
(
clip_version
,
device
,
max_length
=
clip_max_length
)
self
.
t5_encoder
=
FrozenT5Embedder
(
t5_version
,
device
,
max_length
=
t5_max_length
)
print
(
f
"
{
self
.
clip_encoder
.
__class__
.
__name__
}
has
{
count_params
(
self
.
clip_encoder
)
*
1.e-6
:.
2
f
}
M parameters, "
f
"
{
self
.
t5_encoder
.
__class__
.
__name__
}
comes with
{
count_params
(
self
.
t5_encoder
)
*
1.e-6
:.
2
f
}
M params."
)
def
encode
(
self
,
text
):
return
self
(
text
)
def
forward
(
self
,
text
):
clip_z
=
self
.
clip_encoder
.
encode
(
text
)
t5_z
=
self
.
t5_encoder
.
encode
(
text
)
return
[
clip_z
,
t5_z
]
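# --- Illustrative sketch (standalone, mirroring the preprocess methods in the
# --- CLIP image encoders above): inputs are assumed to be in [-1, 1]; kornia
# --- handles the bicubic resize to 224x224 and the channel-wise normalization
# --- with the standard CLIP statistics.
import torch
import kornia

CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])

def clip_preprocess(x: torch.Tensor) -> torch.Tensor:
    # x: (B, 3, H, W) in [-1, 1]
    x = kornia.geometry.resize(
        x, (224, 224), interpolation="bicubic", align_corners=True, antialias=True
    )
    x = (x + 1.0) / 2.0                                  # map to [0, 1]
    return kornia.enhance.normalize(x, CLIP_MEAN, CLIP_STD)

img = torch.rand(2, 3, 320, 512) * 2 - 1
print(clip_preprocess(img).shape)                        # torch.Size([2, 3, 224, 224])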
NVComposer/core/modules/encoders/resampler.py
import
math
import
torch
import
torch.nn
as
nn
from
einops
import
rearrange
,
repeat
class
ImageProjModel
(
nn
.
Module
):
"""Projection Model"""
def
__init__
(
self
,
cross_attention_dim
=
1024
,
clip_embeddings_dim
=
1024
,
clip_extra_context_tokens
=
4
,
):
super
().
__init__
()
self
.
cross_attention_dim
=
cross_attention_dim
self
.
clip_extra_context_tokens
=
clip_extra_context_tokens
self
.
proj
=
nn
.
Linear
(
clip_embeddings_dim
,
self
.
clip_extra_context_tokens
*
cross_attention_dim
)
self
.
norm
=
nn
.
LayerNorm
(
cross_attention_dim
)
def
forward
(
self
,
image_embeds
):
# embeds = image_embeds
embeds
=
image_embeds
.
type
(
list
(
self
.
proj
.
parameters
())[
0
].
dtype
)
clip_extra_context_tokens
=
self
.
proj
(
embeds
).
reshape
(
-
1
,
self
.
clip_extra_context_tokens
,
self
.
cross_attention_dim
)
clip_extra_context_tokens
=
self
.
norm
(
clip_extra_context_tokens
)
return
clip_extra_context_tokens
# FFN
def
FeedForward
(
dim
,
mult
=
4
):
inner_dim
=
int
(
dim
*
mult
)
return
nn
.
Sequential
(
nn
.
LayerNorm
(
dim
),
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
),
nn
.
GELU
(),
nn
.
Linear
(
inner_dim
,
dim
,
bias
=
False
),
)
def
reshape_tensor
(
x
,
heads
):
bs
,
length
,
width
=
x
.
shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x
=
x
.
view
(
bs
,
length
,
heads
,
-
1
)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x
=
x
.
transpose
(
1
,
2
)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x
=
x
.
reshape
(
bs
,
heads
,
length
,
-
1
)
return
x
class
PerceiverAttention
(
nn
.
Module
):
def
__init__
(
self
,
*
,
dim
,
dim_head
=
64
,
heads
=
8
):
super
().
__init__
()
self
.
scale
=
dim_head
**-
0.5
self
.
dim_head
=
dim_head
self
.
heads
=
heads
inner_dim
=
dim_head
*
heads
self
.
norm1
=
nn
.
LayerNorm
(
dim
)
self
.
norm2
=
nn
.
LayerNorm
(
dim
)
self
.
to_q
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
to_kv
=
nn
.
Linear
(
dim
,
inner_dim
*
2
,
bias
=
False
)
self
.
to_out
=
nn
.
Linear
(
inner_dim
,
dim
,
bias
=
False
)
def
forward
(
self
,
x
,
latents
):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x
=
self
.
norm1
(
x
)
latents
=
self
.
norm2
(
latents
)
b
,
l
,
_
=
latents
.
shape
q
=
self
.
to_q
(
latents
)
kv_input
=
torch
.
cat
((
x
,
latents
),
dim
=-
2
)
k
,
v
=
self
.
to_kv
(
kv_input
).
chunk
(
2
,
dim
=-
1
)
q
=
reshape_tensor
(
q
,
self
.
heads
)
k
=
reshape_tensor
(
k
,
self
.
heads
)
v
=
reshape_tensor
(
v
,
self
.
heads
)
# attention
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
self
.
dim_head
))
# More stable with f16 than dividing afterwards
weight
=
(
q
*
scale
)
@
(
k
*
scale
).
transpose
(
-
2
,
-
1
)
weight
=
torch
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
out
=
weight
@
v
out
=
out
.
permute
(
0
,
2
,
1
,
3
).
reshape
(
b
,
l
,
-
1
)
return
self
.
to_out
(
out
)
class
Resampler
(
nn
.
Module
):
def
__init__
(
self
,
dim
=
1024
,
depth
=
8
,
dim_head
=
64
,
heads
=
16
,
num_queries
=
8
,
embedding_dim
=
768
,
output_dim
=
1024
,
ff_mult
=
4
,
video_length
=
None
,
):
super
().
__init__
()
self
.
num_queries
=
num_queries
self
.
video_length
=
video_length
if
video_length
is
not
None
:
num_queries
=
num_queries
*
video_length
self
.
latents
=
nn
.
Parameter
(
torch
.
randn
(
1
,
num_queries
,
dim
)
/
dim
**
0.5
)
self
.
proj_in
=
nn
.
Linear
(
embedding_dim
,
dim
)
self
.
proj_out
=
nn
.
Linear
(
dim
,
output_dim
)
self
.
norm_out
=
nn
.
LayerNorm
(
output_dim
)
self
.
layers
=
nn
.
ModuleList
([])
for
_
in
range
(
depth
):
self
.
layers
.
append
(
nn
.
ModuleList
(
[
PerceiverAttention
(
dim
=
dim
,
dim_head
=
dim_head
,
heads
=
heads
),
FeedForward
(
dim
=
dim
,
mult
=
ff_mult
),
]
)
)
def
forward
(
self
,
x
):
latents
=
self
.
latents
.
repeat
(
x
.
size
(
0
),
1
,
1
)
# B (T L) C
x
=
self
.
proj_in
(
x
)
for
attn
,
ff
in
self
.
layers
:
latents
=
attn
(
x
,
latents
)
+
latents
latents
=
ff
(
latents
)
+
latents
latents
=
self
.
proj_out
(
latents
)
latents
=
self
.
norm_out
(
latents
)
# B L C or B (T L) C
return
latents
class
CameraPoseQueryTransformer
(
nn
.
Module
):
def
__init__
(
self
,
dim
=
1024
,
depth
=
8
,
dim_head
=
64
,
heads
=
16
,
num_queries
=
8
,
embedding_dim
=
768
,
output_dim
=
1024
,
ff_mult
=
4
,
num_views
=
None
,
use_multi_view_attention
=
True
,
):
super
().
__init__
()
self
.
num_queries
=
num_queries
self
.
num_views
=
num_views
assert
num_views
is
not
None
,
"video_length must be given."
self
.
use_multi_view_attention
=
use_multi_view_attention
self
.
camera_pose_embedding_layers
=
nn
.
Sequential
(
nn
.
Linear
(
12
,
dim
),
nn
.
SiLU
(),
nn
.
Linear
(
dim
,
dim
),
nn
.
SiLU
(),
nn
.
Linear
(
dim
,
dim
),
)
nn
.
init
.
zeros_
(
self
.
camera_pose_embedding_layers
[
-
1
].
weight
)
nn
.
init
.
zeros_
(
self
.
camera_pose_embedding_layers
[
-
1
].
bias
)
self
.
latents
=
nn
.
Parameter
(
torch
.
randn
(
1
,
num_views
*
num_queries
,
dim
)
/
dim
**
0.5
)
self
.
proj_in
=
nn
.
Linear
(
embedding_dim
,
dim
)
self
.
proj_out
=
nn
.
Linear
(
dim
,
output_dim
)
self
.
norm_out
=
nn
.
LayerNorm
(
output_dim
)
self
.
layers
=
nn
.
ModuleList
([])
for
_
in
range
(
depth
):
self
.
layers
.
append
(
nn
.
ModuleList
(
[
PerceiverAttention
(
dim
=
dim
,
dim_head
=
dim_head
,
heads
=
heads
),
FeedForward
(
dim
=
dim
,
mult
=
ff_mult
),
]
)
)
def
forward
(
self
,
x
,
camera_poses
):
# camera_poses: (b, t, 12)
batch_size
,
num_views
,
_
=
camera_poses
.
shape
# latents: (1, t*q, d) -> (b, t*q, d)
latents
=
self
.
latents
.
repeat
(
batch_size
,
1
,
1
)
x
=
self
.
proj_in
(
x
)
# camera_poses: (b*t, 12)
camera_poses
=
rearrange
(
camera_poses
,
"b t d -> (b t) d"
,
t
=
num_views
)
camera_poses
=
self
.
camera_pose_embedding_layers
(
camera_poses
)
# camera_poses: (b*t, d)
# camera_poses: (b, t, d)
camera_poses
=
rearrange
(
camera_poses
,
"(b t) d -> b t d"
,
t
=
num_views
)
# camera_poses: (b, t*q, d)
camera_poses
=
repeat
(
camera_poses
,
"b t d -> b (t q) d"
,
q
=
self
.
num_queries
)
latents
=
latents
+
camera_poses
# b, t*q, d
latents
=
rearrange
(
latents
,
"b (t q) d -> (b t) q d"
,
b
=
batch_size
,
t
=
num_views
,
q
=
self
.
num_queries
,
)
# (b*t, q, d)
_
,
x_seq_size
,
_
=
x
.
shape
for
layer_idx
,
(
attn
,
ff
)
in
enumerate
(
self
.
layers
):
if
self
.
use_multi_view_attention
and
layer_idx
%
2
==
1
:
# latents: (b*t, q, d)
latents
=
rearrange
(
latents
,
"(b t) q d -> b (t q) d"
,
b
=
batch_size
,
t
=
num_views
,
q
=
self
.
num_queries
,
)
# x: (b*t, s, d)
x
=
rearrange
(
x
,
"(b t) s d -> b (t s) d"
,
b
=
batch_size
,
t
=
num_views
,
s
=
x_seq_size
)
# print("After rearrange: latents.shape=", latents.shape)
# print("After rearrange: x.shape=", camera_poses.shape)
latents
=
attn
(
x
,
latents
)
+
latents
latents
=
ff
(
latents
)
+
latents
if
self
.
use_multi_view_attention
and
layer_idx
%
2
==
1
:
# latents: (b*q, t, d)
latents
=
rearrange
(
latents
,
"b (t q) d -> (b t) q d"
,
b
=
batch_size
,
t
=
num_views
,
q
=
self
.
num_queries
,
)
# x: (b*s, t, d)
x
=
rearrange
(
x
,
"b (t s) d -> (b t) s d"
,
b
=
batch_size
,
t
=
num_views
,
s
=
x_seq_size
)
latents
=
self
.
proj_out
(
latents
)
latents
=
self
.
norm_out
(
latents
)
# B L C or B (T L) C
return
latents
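# --- Illustrative sketch (hypothetical MiniResamplerSketch, single layer,
# --- torch only; not repository code): the Resampler above compresses a
# --- variable number of image tokens into a fixed set of learned query
# --- latents via cross-attention. A minimal version of that idea:
import torch
import torch.nn as nn

class MiniResamplerSketch(nn.Module):
    def __init__(self, dim=256, num_queries=8, heads=4):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, tokens):                   # tokens: (B, N, dim), N arbitrary
        q = self.latents.expand(tokens.size(0), -1, -1)
        out, _ = self.attn(q, tokens, tokens)    # latents attend to image tokens
        return self.norm(out + q)                # (B, num_queries, dim)

print(MiniResamplerSketch()(torch.randn(2, 257, 256)).shape)   # torch.Size([2, 8, 256])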
NVComposer/core/modules/networks/ae_modules.py
# pytorch_diffusion + derived encoder decoder
import
math
import
torch
import
numpy
as
np
import
torch.nn
as
nn
from
einops
import
rearrange
from
utils.utils
import
instantiate_from_config
from
core.modules.attention
import
LinearAttention
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
,
num_groups
=
32
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
num_groups
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
LinAttnBlock
(
LinearAttention
):
"""to match AttnBlock usage"""
def
__init__
(
self
,
in_channels
):
super
().
__init__
(
dim
=
in_channels
,
heads
=
1
,
dim_head
=
in_channels
)
class
AttnBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
norm
=
Normalize
(
in_channels
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
h_
=
x
h_
=
self
.
norm
(
h_
)
q
=
self
.
q
(
h_
)
k
=
self
.
k
(
h_
)
v
=
self
.
v
(
h_
)
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
# bcl
q
=
q
.
permute
(
0
,
2
,
1
)
# bcl -> blc l=hw
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# bcl
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
# attend to values
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
torch
.
bmm
(
v
,
w_
)
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
h_
=
self
.
proj_out
(
h_
)
return
x
+
h_
def
make_attn
(
in_channels
,
attn_type
=
"vanilla"
):
assert
attn_type
in
[
"vanilla"
,
"linear"
,
"none"
],
f
"attn_type
{
attn_type
}
unknown"
if
attn_type
==
"vanilla"
:
return
AttnBlock
(
in_channels
)
elif
attn_type
==
"none"
:
return
nn
.
Identity
(
in_channels
)
else
:
return
LinAttnBlock
(
in_channels
)
class
Downsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
self
.
in_channels
=
in_channels
if
self
.
with_conv
:
# no asymmetric padding in torch conv, must do it ourselves
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
0
)
def
forward
(
self
,
x
):
if
self
.
with_conv
:
pad
=
(
0
,
1
,
0
,
1
)
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
self
.
conv
(
x
)
else
:
x
=
torch
.
nn
.
functional
.
avg_pool2d
(
x
,
kernel_size
=
2
,
stride
=
2
)
return
x
class
Upsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
self
.
in_channels
=
in_channels
if
self
.
with_conv
:
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
x
=
torch
.
nn
.
functional
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
with_conv
:
x
=
self
.
conv
(
x
)
return
x
def
get_timestep_embedding
(
time_steps
,
embedding_dim
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert
len
(
time_steps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
time_steps
.
device
)
emb
=
time_steps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
,
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
norm1
=
Normalize
(
in_channels
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
self
.
conv_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
else
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
self
.
norm1
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv1
(
h
)
if
temb
is
not
None
:
h
=
h
+
self
.
temb_proj
(
nonlinearity
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
norm2
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
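# --- Illustrative note (standalone, not repository code): the ResnetBlock
# --- above injects the timestep embedding by projecting it to the block's
# --- channel count and broadcasting it over the spatial dimensions, i.e.
# --- h = h + proj(swish(temb))[:, :, None, None].
import torch
import torch.nn as nn

temb_channels, out_channels = 512, 128
temb_proj = nn.Linear(temb_channels, out_channels)

h = torch.randn(4, out_channels, 32, 32)     # feature map
temb = torch.randn(4, temb_channels)         # per-sample timestep embedding
h = h + temb_proj(torch.nn.functional.silu(temb))[:, :, None, None]
print(h.shape)                               # torch.Size([4, 128, 32, 32])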
class
Model
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
use_timestep
=
True
,
use_linear_attn
=
False
,
attn_type
=
"vanilla"
,
):
super
().
__init__
()
if
use_linear_attn
:
attn_type
=
"linear"
self
.
ch
=
ch
self
.
temb_ch
=
self
.
ch
*
4
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
use_timestep
=
use_timestep
if
self
.
use_timestep
:
# timestep embedding
self
.
temb
=
nn
.
Module
()
self
.
temb
.
dense
=
nn
.
ModuleList
(
[
torch
.
nn
.
Linear
(
self
.
ch
,
self
.
temb_ch
),
torch
.
nn
.
Linear
(
self
.
temb_ch
,
self
.
temb_ch
),
]
)
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
make_attn
(
block_in
,
attn_type
=
attn_type
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
self
.
mid
.
attn_1
=
make_attn
(
block_in
,
attn_type
=
attn_type
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
skip_in
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
if
i_block
==
self
.
num_res_blocks
:
skip_in
=
ch
*
in_ch_mult
[
i_level
]
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
+
skip_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
make_attn
(
block_in
,
attn_type
=
attn_type
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
,
t
=
None
,
context
=
None
):
# assert x.shape[2] == x.shape[3] == self.resolution
if
context
is
not
None
:
# assume aligned context, cat along channel axis
x
=
torch
.
cat
((
x
,
context
),
dim
=
1
)
if
self
.
use_timestep
:
# timestep embedding
assert
t
is
not
None
temb
=
get_timestep_embedding
(
t
,
self
.
ch
)
temb
=
self
.
temb
.
dense
[
0
](
temb
)
temb
=
nonlinearity
(
temb
)
temb
=
self
.
temb
.
dense
[
1
](
temb
)
else
:
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
),
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
def
get_last_layer
(
self
):
return
self
.
conv_out
.
weight
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
double_z
=
True
,
use_linear_attn
=
False
,
attn_type
=
"vanilla"
,
**
ignore_kwargs
,
):
super
().
__init__
()
if
use_linear_attn
:
attn_type
=
"linear"
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
in_ch_mult
=
in_ch_mult
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
make_attn
(
block_in
,
attn_type
=
attn_type
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
self
.
mid
.
attn_1
=
make_attn
(
block_in
,
attn_type
=
attn_type
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
2
*
z_channels
if
double_z
else
z_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
)
def
forward
(
self
,
x
):
# timestep embedding
temb
=
None
# print(f'encoder-input={x.shape}')
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
# print(f'encoder-conv in feat={hs[0].shape}')
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
# print(f'encoder-down feat={h.shape}')
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
# print(f'encoder-downsample (input)={hs[-1].shape}')
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# print(f'encoder-downsample (output)={hs[-1].shape}')
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
# print(f'encoder-mid1 feat={h.shape}')
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# print(f'encoder-mid2 feat={h.shape}')
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
# print(f'end feat={h.shape}')
return
h
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
give_pre_end
=
False
,
tanh_out
=
False
,
use_linear_attn
=
False
,
attn_type
=
"vanilla"
,
**
ignored_kwargs
,
):
super
().
__init__
()
if
use_linear_attn
:
attn_type
=
"linear"
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
give_pre_end
=
give_pre_end
self
.
tanh_out
=
tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
# print("AE working on z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
self
.
mid
.
attn_1
=
make_attn
(
block_in
,
attn_type
=
attn_type
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
make_attn
(
block_in
,
attn_type
=
attn_type
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
z
):
# assert z.shape[1:] == self.z_shape[1:]
self
.
last_z_shape
=
z
.
shape
# print(f'decoder-input={z.shape}')
# timestep embedding
temb
=
None
# z to block_in
h
=
self
.
conv_in
(
z
)
# print(f'decoder-conv in feat={h.shape}')
# middle
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# print(f'decoder-mid feat={h.shape}')
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
h
,
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
# print(f'decoder-up feat={h.shape}')
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# print(f'decoder-upsample feat={h.shape}')
# end
if
self
.
give_pre_end
:
return
h
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
# print(f'decoder-conv_out feat={h.shape}')
if
self
.
tanh_out
:
h
=
torch
.
tanh
(
h
)
return
h
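The Encoder and Decoder above form the usual autoencoder pair: the encoder halves the resolution once per extra entry in ch_mult and emits twice z_channels (the mean/log-variance moments) when double_z is set, and the decoder mirrors that schedule back up. A minimal round-trip sketch (editor's addition; the import path and hyperparameters are assumed examples, not taken from a config in this repository):

# Editor's illustrative sketch (assumed import path and example hyperparameters).
import torch
from core.modules.networks.ae_modules import Encoder, Decoder

cfg = dict(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
           attn_resolutions=[], dropout=0.0, in_channels=3,
           resolution=64, z_channels=4)
enc = Encoder(double_z=True, **cfg)
dec = Decoder(**cfg)

x = torch.randn(1, 3, 64, 64)
moments = enc(x)        # (1, 2 * z_channels, 16, 16): downsampled by 2 ** (len(ch_mult) - 1)
z = moments[:, :4]      # use the mean half as a stand-in for sampling
x_rec = dec(z)          # (1, 3, 64, 64)
print(moments.shape, x_rec.shape)
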
class SimpleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList(
            [
                nn.Conv2d(in_channels, in_channels, 1),
                ResnetBlock(
                    in_channels=in_channels,
                    out_channels=2 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                ResnetBlock(
                    in_channels=2 * in_channels,
                    out_channels=4 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                ResnetBlock(
                    in_channels=4 * in_channels,
                    out_channels=2 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                nn.Conv2d(2 * in_channels, in_channels, 1),
                Upsample(in_channels, with_conv=True),
            ]
        )
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        for i, layer in enumerate(self.model):
            if i in [1, 2, 3]:
                x = layer(x, None)
            else:
                x = layer(x)
        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x


class UpsampleDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        ch,
        num_res_blocks,
        resolution,
        ch_mult=(2, 2),
        dropout=0.0,
    ):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        # upsampling
        h = x
        for k, i_level in enumerate(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[k](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class LatentRescaler(nn.Module):
    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(
            in_channels, mid_channels, kernel_size=3, stride=1, padding=1
        )
        self.res_block1 = nn.ModuleList(
            [
                ResnetBlock(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    temb_channels=0,
                    dropout=0.0,
                )
                for _ in range(depth)
            ]
        )
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList(
            [
                ResnetBlock(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    temb_channels=0,
                    dropout=0.0,
                )
                for _ in range(depth)
            ]
        )
        self.conv_out = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        x = torch.nn.functional.interpolate(
            x,
            size=(
                int(round(x.shape[2] * self.factor)),
                int(round(x.shape[3] * self.factor)),
            ),
        )
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        x = self.conv_out(x)
        return x


class MergedRescaleEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        ch,
        resolution,
        out_ch,
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        ch_mult=(1, 2, 4, 8),
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        intermediate_chn = ch * ch_mult[-1]
        self.encoder = Encoder(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=intermediate_chn,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=intermediate_chn,
            mid_channels=intermediate_chn,
            out_channels=out_ch,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.rescaler(x)
        return x


class MergedRescaleDecoder(nn.Module):
    def __init__(
        self,
        z_channels,
        out_ch,
        resolution,
        num_res_blocks,
        attn_resolutions,
        ch,
        ch_mult=(1, 2, 4, 8),
        dropout=0.0,
        resamp_with_conv=True,
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        tmp_chn = z_channels * ch_mult[-1]
        self.decoder = Decoder(
            out_ch=out_ch,
            z_channels=tmp_chn,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=None,
            num_res_blocks=num_res_blocks,
            ch_mult=ch_mult,
            resolution=resolution,
            ch=ch,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=z_channels,
            mid_channels=tmp_chn,
            out_channels=tmp_chn,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Upsampler(nn.Module):
    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks = int(np.log2(out_size // in_size)) + 1
        factor_up = 1.0 + (out_size % in_size)
        print(
            f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
        )
        self.rescaler = LatentRescaler(
            factor=factor_up,
            in_channels=in_channels,
            mid_channels=2 * in_channels,
            out_channels=in_channels,
        )
        self.decoder = Decoder(
            out_ch=out_channels,
            resolution=out_size,
            z_channels=in_channels,
            num_res_blocks=2,
            attn_resolutions=[],
            in_channels=None,
            ch=in_channels,
            ch_mult=[ch_mult for _ in range(num_blocks)],
        )

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Resize(nn.Module):
    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            print(
                f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
            )
            raise NotImplementedError()
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=4, stride=2, padding=1
            )

    def forward(self, x, scale_factor=1.0):
        if scale_factor == 1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(
                x, mode=self.mode, align_corners=False, scale_factor=scale_factor
            )
        return x


class FirstStagePostProcessor(nn.Module):
    def __init__(
        self,
        ch_mult: list,
        in_channels,
        pretrained_model: nn.Module = None,
        reshape=False,
        n_channels=None,
        dropout=0.0,
        pretrained_config=None,
    ):
        super().__init__()
        if pretrained_config is None:
            assert (
                pretrained_model is not None
            ), 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model
        else:
            assert (
                pretrained_config is not None
            ), 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.instantiate_pretrained(pretrained_config)
        self.do_reshape = reshape
        if n_channels is None:
            n_channels = self.pretrained_model.encoder.ch
        self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
        self.proj = nn.Conv2d(
            in_channels, n_channels, kernel_size=3, stride=1, padding=1
        )
        blocks = []
        downs = []
        ch_in = n_channels
        for m in ch_mult:
            blocks.append(
                ResnetBlock(
                    in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
                )
            )
            ch_in = m * n_channels
            downs.append(Downsample(ch_in, with_conv=False))
        self.model = nn.ModuleList(blocks)
        self.downsampler = nn.ModuleList(downs)

    def instantiate_pretrained(self, config):
        model = instantiate_from_config(config)
        self.pretrained_model = model.eval()
        # self.pretrained_model.train = False
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def encode_with_pretrained(self, x):
        c = self.pretrained_model.encode(x)
        if isinstance(c, DiagonalGaussianDistribution):
            c = c.mode()
        return c

    def forward(self, x):
        z_fs = self.encode_with_pretrained(x)
        z = self.proj_norm(z_fs)
        z = self.proj(z)
        z = nonlinearity(z)
        for submodel, downmodel in zip(self.model, self.downsampler):
            z = submodel(z, temb=None)
            z = downmodel(z)
        if self.do_reshape:
            z = rearrange(z, "b c h w -> b (h w) c")
        return z

NVComposer/core/modules/networks/unet_modules.py
0 → 100755
from functools import partial
from abc import abstractmethod

import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F

from core.models.utils_diffusion import timestep_embedding
from core.common import gradient_checkpoint
from core.basics import zero_module, conv_nd, linear, avg_pool_nd, normalization
from core.modules.attention import SpatialTransformer, TemporalTransformer

TASK_IDX_IMAGE = 0
TASK_IDX_RAY = 1


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(
        self, x, emb, context=None, batch_size=None, with_lora=False, time_steps=None
    ):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb, batch_size=batch_size)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context, with_lora=with_lora)
            elif isinstance(layer, TemporalTransformer):
                x = rearrange(x, "(b f) c h w -> b c f h w", b=batch_size)
                x = layer(x, context, with_lora=with_lora, time_steps=time_steps)
                x = rearrange(x, "b c f h w -> (b f) c h w")
            else:
                x = layer(x)
        return x

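A small dispatch example for TimestepEmbedSequential (editor's sketch; the import path is assumed, and it presumes the repository's gradient_checkpoint helper simply calls the wrapped function when use_checkpoint is False):

# Editor's illustrative sketch (assumed import path; not part of the original file).
import torch
import torch.nn as nn
from core.modules.networks.unet_modules import TimestepEmbedSequential, ResBlock

seq = TimestepEmbedSequential(
    ResBlock(channels=32, emb_channels=128, dropout=0.0, out_channels=64),
    nn.SiLU(),  # plain modules only receive x
)
x = torch.randn(4, 32, 16, 16)
emb = torch.randn(4, 128)
y = seq(x, emb)    # the ResBlock child additionally receives emb
print(y.shape)     # torch.Size([4, 64, 16, 16])
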
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(
                dims, self.channels, self.out_channels, 3, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x,
                (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
                mode="nearest",
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x

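With dims=3 both blocks treat the first of the three trailing dimensions as time and only resize the inner two spatial dimensions, as the docstrings above state. A shape check (editor's sketch; assumed import path):

# Editor's illustrative sketch (assumed import path; not part of the original file).
import torch
from core.modules.networks.unet_modules import Downsample, Upsample

x = torch.randn(1, 8, 16, 32, 32)   # (batch, channels, frames, height, width)
print(Downsample(channels=8, use_conv=True, dims=3)(x).shape)  # torch.Size([1, 8, 16, 16, 16])
print(Upsample(channels=8, use_conv=True, dims=3)(x).shape)    # torch.Size([1, 8, 16, 64, 64])
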
class
ResBlock
(
TimestepBlock
):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
:param use_temporal_conv: if True, use the temporal convolution.
:param use_image_dataset: if True, the temporal parameters will not be optimized.
"""
def
__init__
(
self
,
channels
,
emb_channels
,
dropout
,
out_channels
=
None
,
use_scale_shift_norm
=
False
,
dims
=
2
,
use_checkpoint
=
False
,
use_conv
=
False
,
up
=
False
,
down
=
False
,
use_temporal_conv
=
False
,
tempspatial_aware
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
self
.
emb_channels
=
emb_channels
self
.
dropout
=
dropout
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
use_checkpoint
=
use_checkpoint
self
.
use_scale_shift_norm
=
use_scale_shift_norm
self
.
use_temporal_conv
=
use_temporal_conv
self
.
in_layers
=
nn
.
Sequential
(
normalization
(
channels
),
nn
.
SiLU
(),
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
),
)
self
.
updown
=
up
or
down
if
up
:
self
.
h_upd
=
Upsample
(
channels
,
False
,
dims
)
self
.
x_upd
=
Upsample
(
channels
,
False
,
dims
)
elif
down
:
self
.
h_upd
=
Downsample
(
channels
,
False
,
dims
)
self
.
x_upd
=
Downsample
(
channels
,
False
,
dims
)
else
:
self
.
h_upd
=
self
.
x_upd
=
nn
.
Identity
()
self
.
emb_layers
=
nn
.
Sequential
(
nn
.
SiLU
(),
nn
.
Linear
(
emb_channels
,
2
*
self
.
out_channels
if
use_scale_shift_norm
else
self
.
out_channels
,
),
)
self
.
out_layers
=
nn
.
Sequential
(
normalization
(
self
.
out_channels
),
nn
.
SiLU
(),
nn
.
Dropout
(
p
=
dropout
),
zero_module
(
nn
.
Conv2d
(
self
.
out_channels
,
self
.
out_channels
,
3
,
padding
=
1
)),
)
if
self
.
out_channels
==
channels
:
self
.
skip_connection
=
nn
.
Identity
()
elif
use_conv
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
else
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
if
self
.
use_temporal_conv
:
self
.
temopral_conv
=
TemporalConvBlock
(
self
.
out_channels
,
self
.
out_channels
,
dropout
=
0.1
,
spatial_aware
=
tempspatial_aware
,
)
def
forward
(
self
,
x
,
emb
,
batch_size
=
None
):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
input_tuple
=
(
x
,
emb
)
if
batch_size
:
forward_batchsize
=
partial
(
self
.
_forward
,
batch_size
=
batch_size
)
return
gradient_checkpoint
(
forward_batchsize
,
input_tuple
,
self
.
parameters
(),
self
.
use_checkpoint
)
return
gradient_checkpoint
(
self
.
_forward
,
input_tuple
,
self
.
parameters
(),
self
.
use_checkpoint
)
def
_forward
(
self
,
x
,
emb
,
batch_size
=
None
):
if
self
.
updown
:
in_rest
,
in_conv
=
self
.
in_layers
[:
-
1
],
self
.
in_layers
[
-
1
]
h
=
in_rest
(
x
)
h
=
self
.
h_upd
(
h
)
x
=
self
.
x_upd
(
x
)
h
=
in_conv
(
h
)
else
:
h
=
self
.
in_layers
(
x
)
emb_out
=
self
.
emb_layers
(
emb
).
type
(
h
.
dtype
)
while
len
(
emb_out
.
shape
)
<
len
(
h
.
shape
):
emb_out
=
emb_out
[...,
None
]
if
self
.
use_scale_shift_norm
:
out_norm
,
out_rest
=
self
.
out_layers
[
0
],
self
.
out_layers
[
1
:]
scale
,
shift
=
torch
.
chunk
(
emb_out
,
2
,
dim
=
1
)
h
=
out_norm
(
h
)
*
(
1
+
scale
)
+
shift
h
=
out_rest
(
h
)
else
:
h
=
h
+
emb_out
h
=
self
.
out_layers
(
h
)
h
=
self
.
skip_connection
(
x
)
+
h
if
self
.
use_temporal_conv
and
batch_size
:
h
=
rearrange
(
h
,
"(b t) c h w -> b c t h w"
,
b
=
batch_size
)
h
=
self
.
temopral_conv
(
h
)
h
=
rearrange
(
h
,
"b c t h w -> (b t) c h w"
)
return
h
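The use_scale_shift_norm branch in _forward above is the FiLM-like conditioning mentioned in the docstring: the embedding is split into a scale and a shift that modulate the normalized features rather than being added to them. A standalone sketch of just that arithmetic (editor's addition; torch.nn.functional.group_norm stands in for the block's normalization layer):

# Editor's illustrative sketch of the scale-shift (FiLM-like) conditioning.
import torch
import torch.nn.functional as F

h = torch.randn(4, 64, 16, 16)        # features after in_layers
emb_out = torch.randn(4, 128, 1, 1)   # emb_layers output: 2 * out_channels when scale-shift is used
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = F.group_norm(h, num_groups=32) * (1 + scale) + shift
print(h.shape)                        # torch.Size([4, 64, 16, 16])
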
class TemporalConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False):
        super(TemporalConvBlock, self).__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
        th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
        tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
        tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_channels),
            nn.SiLU(),
            nn.Conv3d(in_channels, out_channels, th_kernel_shape, padding=th_padding_shape),
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_channels, in_channels, th_kernel_shape, padding=th_padding_shape),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape),
        )
        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return identity + x

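Because conv4 is zero-initialized, a freshly constructed TemporalConvBlock is an exact identity mapping, so inserting it into a pretrained network does not perturb the output at the start of training. A quick check (editor's sketch; assumed import path):

# Editor's illustrative sketch (assumed import path; not part of the original file).
import torch
from core.modules.networks.unet_modules import TemporalConvBlock

block = TemporalConvBlock(in_channels=64)
x = torch.randn(1, 64, 8, 16, 16)    # (batch, channels, frames, height, width)
print(torch.allclose(block(x), x))   # True at initialization
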
class
UNetModel
(
nn
.
Module
):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: in_channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_head_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def
__init__
(
self
,
in_channels
,
model_channels
,
out_channels
,
num_res_blocks
,
attention_resolutions
,
dropout
=
0.0
,
channel_mult
=
(
1
,
2
,
4
,
8
),
conv_resample
=
True
,
dims
=
2
,
context_dim
=
None
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
num_heads
=-
1
,
num_head_channels
=-
1
,
transformer_depth
=
1
,
use_linear
=
False
,
use_checkpoint
=
False
,
temporal_conv
=
False
,
tempspatial_aware
=
False
,
temporal_attention
=
True
,
use_relative_position
=
True
,
use_causal_attention
=
False
,
temporal_length
=
None
,
use_fp16
=
False
,
addition_attention
=
False
,
temporal_selfatt_only
=
True
,
image_cross_attention
=
False
,
image_cross_attention_scale_learnable
=
False
,
default_fs
=
4
,
fs_condition
=
False
,
use_spatial_temporal_attention
=
False
,
# >>> Extra Ray Options
use_addition_ray_output_head
=
False
,
ray_channels
=
6
,
use_lora_for_rays_in_output_blocks
=
False
,
use_task_embedding
=
False
,
use_ray_decoder
=
False
,
use_ray_decoder_residual
=
False
,
full_spatial_temporal_attention
=
False
,
enhance_multi_view_correspondence
=
False
,
camera_pose_condition
=
False
,
use_feature_alignment
=
False
,
):
super
(
UNetModel
,
self
).
__init__
()
if
num_heads
==
-
1
:
assert
(
num_head_channels
!=
-
1
),
"Either num_heads or num_head_channels has to be set"
if
num_head_channels
==
-
1
:
assert
(
num_heads
!=
-
1
),
"Either num_heads or num_head_channels has to be set"
self
.
in_channels
=
in_channels
self
.
model_channels
=
model_channels
self
.
out_channels
=
out_channels
self
.
num_res_blocks
=
num_res_blocks
self
.
attention_resolutions
=
attention_resolutions
self
.
dropout
=
dropout
self
.
channel_mult
=
channel_mult
self
.
conv_resample
=
conv_resample
self
.
temporal_attention
=
temporal_attention
time_embed_dim
=
model_channels
*
4
self
.
use_checkpoint
=
use_checkpoint
self
.
dtype
=
torch
.
float16
if
use_fp16
else
torch
.
float32
temporal_self_att_only
=
True
self
.
addition_attention
=
addition_attention
self
.
temporal_length
=
temporal_length
self
.
image_cross_attention
=
image_cross_attention
self
.
image_cross_attention_scale_learnable
=
(
image_cross_attention_scale_learnable
)
self
.
default_fs
=
default_fs
self
.
fs_condition
=
fs_condition
self
.
use_spatial_temporal_attention
=
use_spatial_temporal_attention
# >>> Extra Ray Options
self
.
use_addition_ray_output_head
=
use_addition_ray_output_head
self
.
use_lora_for_rays_in_output_blocks
=
use_lora_for_rays_in_output_blocks
if
self
.
use_lora_for_rays_in_output_blocks
:
assert
(
use_addition_ray_output_head
),
"`use_addition_ray_output_head` is required to be True when using LoRA for rays in output blocks."
assert
(
not
use_task_embedding
),
"`use_task_embedding` cannot be True when `use_lora_for_rays_in_output_blocks` is enabled."
if
self
.
use_addition_ray_output_head
:
print
(
"Using additional ray output head..."
)
assert
(
self
.
out_channels
==
4
)
or
(
4
+
ray_channels
==
self
.
out_channels
),
f
"`out_channels`=
{
out_channels
}
is invalid."
self
.
out_channels
=
4
out_channels
=
4
self
.
ray_channels
=
ray_channels
self
.
use_ray_decoder
=
use_ray_decoder
if
use_ray_decoder
:
assert
(
not
use_task_embedding
),
"`use_task_embedding` cannot be True when `use_ray_decoder_layers` is enabled."
assert
(
use_addition_ray_output_head
),
"`use_addition_ray_output_head` must be True when `use_ray_decoder_layers` is enabled."
self
.
use_ray_decoder_residual
=
use_ray_decoder_residual
# >>> Time/Task Embedding Blocks
self
.
time_embed
=
nn
.
Sequential
(
linear
(
model_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
)
if
fs_condition
:
self
.
fps_embedding
=
nn
.
Sequential
(
linear
(
model_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
)
nn
.
init
.
zeros_
(
self
.
fps_embedding
[
-
1
].
weight
)
nn
.
init
.
zeros_
(
self
.
fps_embedding
[
-
1
].
bias
)
if
camera_pose_condition
:
self
.
camera_pose_condition
=
True
self
.
camera_pose_embedding
=
nn
.
Sequential
(
linear
(
12
,
model_channels
),
nn
.
SiLU
(),
linear
(
model_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
)
nn
.
init
.
zeros_
(
self
.
camera_pose_embedding
[
-
1
].
weight
)
nn
.
init
.
zeros_
(
self
.
camera_pose_embedding
[
-
1
].
bias
)
self
.
use_task_embedding
=
use_task_embedding
if
use_task_embedding
:
assert
(
not
use_lora_for_rays_in_output_blocks
),
"`use_lora_for_rays_in_output_blocks` and `use_task_embedding` cannot be True at the same time."
assert
(
use_addition_ray_output_head
),
"`use_addition_ray_output_head` is required to be True when `use_task_embedding` is enabled."
self
.
task_embedding
=
nn
.
Sequential
(
linear
(
model_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
)
nn
.
init
.
zeros_
(
self
.
task_embedding
[
-
1
].
weight
)
nn
.
init
.
zeros_
(
self
.
task_embedding
[
-
1
].
bias
)
self
.
task_parameters
=
nn
.
ParameterList
(
[
nn
.
Parameter
(
torch
.
zeros
(
size
=
[
model_channels
],
requires_grad
=
True
)
),
nn
.
Parameter
(
torch
.
zeros
(
size
=
[
model_channels
],
requires_grad
=
True
)
),
]
)
# >>> Input Block
self
.
input_blocks
=
nn
.
ModuleList
(
[
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
)
)
]
)
if
self
.
addition_attention
:
self
.
init_attn
=
TimestepEmbedSequential
(
TemporalTransformer
(
model_channels
,
n_heads
=
8
,
d_head
=
num_head_channels
,
depth
=
transformer_depth
,
context_dim
=
context_dim
,
use_checkpoint
=
use_checkpoint
,
only_self_att
=
temporal_selfatt_only
,
causal_attention
=
False
,
relative_position
=
use_relative_position
,
temporal_length
=
temporal_length
,
)
)
input_block_chans
=
[
model_channels
]
ch
=
model_channels
ds
=
1
for
level
,
mult
in
enumerate
(
channel_mult
):
for
_
in
range
(
num_res_blocks
):
layers
=
[
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
mult
*
model_channels
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
tempspatial_aware
=
tempspatial_aware
,
use_temporal_conv
=
temporal_conv
,
)
]
ch
=
mult
*
model_channels
if
ds
in
attention_resolutions
:
if
num_head_channels
==
-
1
:
dim_head
=
ch
//
num_heads
else
:
num_heads
=
ch
//
num_head_channels
dim_head
=
num_head_channels
layers
.
append
(
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
,
context_dim
=
context_dim
,
use_linear
=
use_linear
,
use_checkpoint
=
use_checkpoint
,
disable_self_attn
=
False
,
video_length
=
temporal_length
,
image_cross_attention
=
self
.
image_cross_attention
,
image_cross_attention_scale_learnable
=
self
.
image_cross_attention_scale_learnable
,
)
)
if
self
.
temporal_attention
:
layers
.
append
(
TemporalTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
,
context_dim
=
context_dim
,
use_linear
=
use_linear
,
use_checkpoint
=
use_checkpoint
,
only_self_att
=
temporal_self_att_only
,
causal_attention
=
use_causal_attention
,
relative_position
=
use_relative_position
,
temporal_length
=
temporal_length
,
)
)
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
input_block_chans
.
append
(
ch
)
if
level
!=
len
(
channel_mult
)
-
1
:
out_ch
=
ch
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
out_ch
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
down
=
True
,
)
if
resblock_updown
else
Downsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
)
)
ch
=
out_ch
input_block_chans
.
append
(
ch
)
ds
*=
2
if
num_head_channels
==
-
1
:
dim_head
=
ch
//
num_heads
else
:
num_heads
=
ch
//
num_head_channels
dim_head
=
num_head_channels
layers
=
[
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
tempspatial_aware
=
tempspatial_aware
,
use_temporal_conv
=
temporal_conv
,
),
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
,
context_dim
=
context_dim
,
use_linear
=
use_linear
,
use_checkpoint
=
use_checkpoint
,
disable_self_attn
=
False
,
video_length
=
temporal_length
,
image_cross_attention
=
self
.
image_cross_attention
,
image_cross_attention_scale_learnable
=
self
.
image_cross_attention_scale_learnable
,
),
]
if
self
.
temporal_attention
:
layers
.
append
(
TemporalTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
,
context_dim
=
context_dim
,
use_linear
=
use_linear
,
use_checkpoint
=
use_checkpoint
,
only_self_att
=
temporal_self_att_only
,
causal_attention
=
use_causal_attention
,
relative_position
=
use_relative_position
,
temporal_length
=
temporal_length
,
)
)
layers
.
append
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
tempspatial_aware
=
tempspatial_aware
,
use_temporal_conv
=
temporal_conv
,
)
)
# >>> Middle Block
self
.
middle_block
=
TimestepEmbedSequential
(
*
layers
)
# >>> Ray Decoder
if
use_ray_decoder
:
self
.
ray_decoder_blocks
=
nn
.
ModuleList
([])
# >>> Output Block
is_first_layer
=
True
self
.
output_blocks
=
nn
.
ModuleList
([])
for
level
,
mult
in
list
(
enumerate
(
channel_mult
))[::
-
1
]:
for
i
in
range
(
num_res_blocks
+
1
):
ich
=
input_block_chans
.
pop
()
layers
=
[
ResBlock
(
ch
+
ich
,
time_embed_dim
,
dropout
,
out_channels
=
mult
*
model_channels
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
tempspatial_aware
=
tempspatial_aware
,
use_temporal_conv
=
temporal_conv
,
)
]
if
use_ray_decoder
:
if
self
.
use_ray_decoder_residual
:
ray_residual_ch
=
ich
else
:
ray_residual_ch
=
0
ray_decoder_layers
=
[
ResBlock
(
(
ch
if
is_first_layer
else
(
ch
//
10
))
+
ray_residual_ch
,
time_embed_dim
,
dropout
,
out_channels
=
mult
*
model_channels
//
10
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
tempspatial_aware
=
tempspatial_aware
,
use_temporal_conv
=
True
,
)
]
is_first_layer
=
False
ch
=
model_channels
*
mult
if
ds
in
attention_resolutions
:
if
num_head_channels
==
-
1
:
dim_head
=
ch
//
num_heads
else
:
num_heads
=
ch
//
num_head_channels
dim_head
=
num_head_channels
layers
.
append
(
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
,
context_dim
=
context_dim
,
use_linear
=
use_linear
,
use_checkpoint
=
use_checkpoint
,
disable_self_attn
=
False
,
video_length
=
temporal_length
,
image_cross_attention
=
self
.
image_cross_attention
,
image_cross_attention_scale_learnable
=
self
.
image_cross_attention_scale_learnable
,
enable_lora
=
self
.
use_lora_for_rays_in_output_blocks
,
)
)
if
self
.
temporal_attention
:
layers
.
append
(
TemporalTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
,
context_dim
=
context_dim
,
use_linear
=
use_linear
,
use_checkpoint
=
use_checkpoint
,
only_self_att
=
temporal_self_att_only
,
causal_attention
=
use_causal_attention
,
relative_position
=
use_relative_position
,
temporal_length
=
temporal_length
,
use_extra_spatial_temporal_self_attention
=
use_spatial_temporal_attention
,
enable_lora
=
self
.
use_lora_for_rays_in_output_blocks
,
full_spatial_temporal_attention
=
full_spatial_temporal_attention
,
enhance_multi_view_correspondence
=
enhance_multi_view_correspondence
,
)
)
if
level
and
i
==
num_res_blocks
:
out_ch
=
ch
# out_ray_ch = ray_ch
layers
.
append
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
out_ch
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
up
=
True
,
)
if
resblock_updown
else
Upsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
)
if
use_ray_decoder
:
ray_decoder_layers
.
append
(
ResBlock
(
ch
//
10
,
time_embed_dim
,
dropout
,
out_channels
=
out_ch
//
10
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
up
=
True
,
)
if
resblock_updown
else
Upsample
(
ch
//
10
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
//
10
,
)
)
ds
//=
2
self
.
output_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
if
use_ray_decoder
:
self
.
ray_decoder_blocks
.
append
(
TimestepEmbedSequential
(
*
ray_decoder_layers
)
)
self
.
out
=
nn
.
Sequential
(
normalization
(
ch
),
nn
.
SiLU
(),
zero_module
(
conv_nd
(
dims
,
model_channels
,
out_channels
,
3
,
padding
=
1
)),
)
if
self
.
use_addition_ray_output_head
:
ray_model_channels
=
model_channels
//
10
self
.
ray_output_head
=
nn
.
Sequential
(
normalization
(
ray_model_channels
),
nn
.
SiLU
(),
conv_nd
(
dims
,
ray_model_channels
,
ray_model_channels
,
3
,
padding
=
1
),
nn
.
SiLU
(),
conv_nd
(
dims
,
ray_model_channels
,
ray_model_channels
,
3
,
padding
=
1
),
nn
.
SiLU
(),
zero_module
(
conv_nd
(
dims
,
ray_model_channels
,
self
.
ray_channels
,
3
,
padding
=
1
)
),
)
self
.
use_feature_alignment
=
use_feature_alignment
if
self
.
use_feature_alignment
:
self
.
feature_alignment_adapter
=
FeatureAlignmentAdapter
(
time_embed_dim
=
time_embed_dim
,
use_checkpoint
=
use_checkpoint
)
def
forward
(
self
,
x
,
time_steps
,
context
=
None
,
features_adapter
=
None
,
fs
=
None
,
task_idx
=
None
,
camera_poses
=
None
,
return_input_block_features
=
False
,
return_middle_feature
=
False
,
return_output_block_features
=
False
,
**
kwargs
,
):
intermediate_features
=
{}
if
return_input_block_features
:
intermediate_features
[
"input"
]
=
[]
if
return_output_block_features
:
intermediate_features
[
"output"
]
=
[]
b
,
t
,
_
,
_
,
_
=
x
.
shape
t_emb
=
timestep_embedding
(
time_steps
,
self
.
model_channels
,
repeat_only
=
False
).
type
(
x
.
dtype
)
emb
=
self
.
time_embed
(
t_emb
)
# repeat t times for context [(b t) 77 768] & time embedding
# check if we use per-frame image conditioning
_
,
l_context
,
_
=
context
.
shape
if
l_context
==
77
+
t
*
16
:
# !!! HARD CODE here
context_text
,
context_img
=
context
[:,
:
77
,
:],
context
[:,
77
:,
:]
context_text
=
context_text
.
repeat_interleave
(
repeats
=
t
,
dim
=
0
)
context_img
=
rearrange
(
context_img
,
"b (t l) c -> (b t) l c"
,
t
=
t
)
context
=
torch
.
cat
([
context_text
,
context_img
],
dim
=
1
)
else
:
context
=
context
.
repeat_interleave
(
repeats
=
t
,
dim
=
0
)
emb
=
emb
.
repeat_interleave
(
repeats
=
t
,
dim
=
0
)
# always in shape (b t) c h w, except for temporal layer
x
=
rearrange
(
x
,
"b t c h w -> (b t) c h w"
)
# combine emb
if
self
.
fs_condition
:
if
fs
is
None
:
fs
=
torch
.
tensor
(
[
self
.
default_fs
]
*
b
,
dtype
=
torch
.
long
,
device
=
x
.
device
)
fs_emb
=
timestep_embedding
(
fs
,
self
.
model_channels
,
repeat_only
=
False
).
type
(
x
.
dtype
)
fs_embed
=
self
.
fps_embedding
(
fs_emb
)
fs_embed
=
fs_embed
.
repeat_interleave
(
repeats
=
t
,
dim
=
0
)
emb
=
emb
+
fs_embed
if
self
.
camera_pose_condition
:
# camera_poses: (b, t, 12)
camera_poses
=
rearrange
(
camera_poses
,
"b t x y -> (b t) (x y)"
)
# x=3, y=4
camera_poses_embed
=
self
.
camera_pose_embedding
(
camera_poses
)
emb
=
emb
+
camera_poses_embed
if
self
.
use_task_embedding
:
assert
(
task_idx
is
not
None
),
"`task_idx` should not be None when `use_task_embedding` is enabled."
task_embed
=
self
.
task_embedding
(
self
.
task_parameters
[
task_idx
]
.
reshape
(
1
,
self
.
model_channels
)
.
repeat
(
b
,
1
)
)
task_embed
=
task_embed
.
repeat_interleave
(
repeats
=
t
,
dim
=
0
)
emb
=
emb
+
task_embed
h
=
x
.
type
(
self
.
dtype
)
adapter_idx
=
0
hs
=
[]
for
_id
,
module
in
enumerate
(
self
.
input_blocks
):
h
=
module
(
h
,
emb
,
context
=
context
,
batch_size
=
b
)
if
_id
==
0
and
self
.
addition_attention
:
h
=
self
.
init_attn
(
h
,
emb
,
context
=
context
,
batch_size
=
b
)
# plug-in adapter features
if
((
_id
+
1
)
%
3
==
0
)
and
features_adapter
is
not
None
:
h
=
h
+
features_adapter
[
adapter_idx
]
adapter_idx
+=
1
hs
.
append
(
h
)
if
return_input_block_features
:
intermediate_features
[
"input"
].
append
(
h
)
if
features_adapter
is
not
None
:
assert
len
(
features_adapter
)
==
adapter_idx
,
"Wrong features_adapter"
h
=
self
.
middle_block
(
h
,
emb
,
context
=
context
,
batch_size
=
b
)
if
return_middle_feature
:
intermediate_features
[
"middle"
]
=
h
if
self
.
use_feature_alignment
:
feature_alignment_output
=
self
.
feature_alignment_adapter
(
hs
[
2
],
hs
[
5
],
hs
[
8
],
emb
=
emb
)
# >>> Output Blocks Forward
if
self
.
use_ray_decoder
:
h_original
=
h
h_ray
=
h
for
original_module
,
ray_module
in
zip
(
self
.
output_blocks
,
self
.
ray_decoder_blocks
):
cur_hs
=
hs
.
pop
()
h_original
=
torch
.
cat
([
h_original
,
cur_hs
],
dim
=
1
)
h_original
=
original_module
(
h_original
,
emb
,
context
=
context
,
batch_size
=
b
,
time_steps
=
time_steps
,
)
if
self
.
use_ray_decoder_residual
:
h_ray
=
torch
.
cat
([
h_ray
,
cur_hs
],
dim
=
1
)
h_ray
=
ray_module
(
h_ray
,
emb
,
context
=
context
,
batch_size
=
b
)
if
return_output_block_features
:
print
(
"return_output_block_features: h_original.shape="
,
h_original
.
shape
,
)
intermediate_features
[
"output"
].
append
(
h_original
.
detach
())
h_original
=
h_original
.
type
(
x
.
dtype
)
h_ray
=
h_ray
.
type
(
x
.
dtype
)
y_original
=
self
.
out
(
h_original
)
y_ray
=
self
.
ray_output_head
(
h_ray
)
y
=
torch
.
cat
([
y_original
,
y_ray
],
dim
=
1
)
else
:
if
self
.
use_lora_for_rays_in_output_blocks
:
middle_h
=
h
h_original
=
middle_h
h_lora
=
middle_h
for
output_idx
,
module
in
enumerate
(
self
.
output_blocks
):
cur_hs
=
hs
.
pop
()
h_original
=
torch
.
cat
([
h_original
,
cur_hs
],
dim
=
1
)
h_original
=
module
(
h_original
,
emb
,
context
=
context
,
batch_size
=
b
,
with_lora
=
False
)
h_lora
=
torch
.
cat
([
h_lora
,
cur_hs
],
dim
=
1
)
h_lora
=
module
(
h_lora
,
emb
,
context
=
context
,
batch_size
=
b
,
with_lora
=
True
)
h_original
=
h_original
.
type
(
x
.
dtype
)
h_lora
=
h_lora
.
type
(
x
.
dtype
)
y_original
=
self
.
out
(
h_original
)
y_lora
=
self
.
ray_output_head
(
h_lora
)
y
=
torch
.
cat
([
y_original
,
y_lora
],
dim
=
1
)
else
:
for
module
in
self
.
output_blocks
:
h
=
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
module
(
h
,
emb
,
context
=
context
,
batch_size
=
b
)
h
=
h
.
type
(
x
.
dtype
)
if
self
.
use_task_embedding
:
# Separated Input (Branch Control on CPU)
# Serial Execution (GPU Vectorization Pending)
if
task_idx
==
TASK_IDX_IMAGE
:
y
=
self
.
out
(
h
)
elif
task_idx
==
TASK_IDX_RAY
:
y
=
self
.
ray_output_head
(
h
)
else
:
raise
NotImplementedError
(
f
"Unsupported `task_idx`:
{
task_idx
}
"
)
else
:
# Output ray and images at the same forward
y
=
self
.
out
(
h
)
if
self
.
use_addition_ray_output_head
:
y_ray
=
self
.
ray_output_head
(
h
)
y
=
torch
.
cat
([
y
,
y_ray
],
dim
=
1
)
# reshape back to (b c t h w)
y
=
rearrange
(
y
,
"(b t) c h w -> b t c h w"
,
b
=
b
)
if
(
return_input_block_features
or
return_output_block_features
or
return_middle_feature
):
return
y
,
intermediate_features
# Assume intermediate features are only requested during non-training scenarios (e.g., feature visualization)
if
self
.
use_feature_alignment
:
return
y
,
feature_alignment_output
return
y
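The hard-coded check on the context length in forward above implements per-frame image conditioning: the first 77 tokens are the text prompt (shared by every frame) and the remaining t * 16 tokens are image tokens that belong to individual frames. The reshaping it performs, in isolation (editor's sketch; the embedding width 1024 is an arbitrary example):

# Editor's illustrative sketch of the per-frame context split (shapes only).
import torch
from einops import rearrange

b, t = 2, 8
context = torch.randn(b, 77 + t * 16, 1024)
context_text, context_img = context[:, :77, :], context[:, 77:, :]
context_text = context_text.repeat_interleave(repeats=t, dim=0)      # (b*t, 77, 1024)
context_img = rearrange(context_img, "b (t l) c -> (b t) l c", t=t)  # (b*t, 16, 1024)
per_frame_context = torch.cat([context_text, context_img], dim=1)    # (b*t, 93, 1024)
print(per_frame_context.shape)
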
class FeatureAlignmentAdapter(torch.nn.Module):
    def __init__(self, time_embed_dim, use_checkpoint, dropout=0.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.channel_adapter_conv_16 = torch.nn.Conv2d(
            in_channels=1280, out_channels=320, kernel_size=1
        )
        self.channel_adapter_conv_32 = torch.nn.Conv2d(
            in_channels=640, out_channels=320, kernel_size=1
        )
        self.upsampler_x2 = torch.nn.UpsamplingBilinear2d(scale_factor=2)
        self.upsampler_x4 = torch.nn.UpsamplingBilinear2d(scale_factor=4)
        self.res_block = ResBlock(
            320 * 3,
            time_embed_dim,
            dropout,
            out_channels=32 * 3,
            dims=2,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=False,
        )
        self.final_conv = conv_nd(
            dims=2, in_channels=32 * 3, out_channels=6, kernel_size=1
        )

    def forward(self, feature_64, feature_32, feature_16, emb):
        feature_16_adapted = self.channel_adapter_conv_16(feature_16)
        feature_32_adapted = self.channel_adapter_conv_32(feature_32)
        feature_16_upsampled = self.upsampler_x4(feature_16_adapted)
        feature_32_upsampled = self.upsampler_x2(feature_32_adapted)
        feature_all = torch.concat(
            [feature_16_upsampled, feature_32_upsampled, feature_64], dim=1
        )
        # bt, 3, h, w
        return self.final_conv(self.res_block(feature_all, emb=emb))

NVComposer/core/modules/position_encoding.py
0 → 100755
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math

import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, token_tensors):
        # input: (B,C,H,W)
        x = token_tensors
        h, w = x.shape[-2:]
        identity_map = torch.ones((h, w), device=x.device)
        y_embed = identity_map.cumsum(0, dtype=torch.float32)
        x_embed = identity_map.cumsum(1, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[-1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, None] / dim_t
        pos_y = y_embed[:, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
        ).flatten(2)
        pos_y = torch.stack(
            (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
        ).flatten(2)
        pos = torch.cat((pos_y, pos_x), dim=2).permute(2, 0, 1)
        batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return batch_pos


class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """

    def __init__(self, n_pos_x=16, n_pos_y=16, num_pos_feats=64):
        super().__init__()
        self.row_embed = nn.Embedding(n_pos_y, num_pos_feats)
        self.col_embed = nn.Embedding(n_pos_x, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, token_tensors):
        # input: (B,C,H,W)
        x = token_tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat(
            [
                x_emb.unsqueeze(0).repeat(h, 1, 1),
                y_emb.unsqueeze(1).repeat(1, w, 1),
            ],
            dim=-1,
        ).permute(2, 0, 1)
        batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return batch_pos


def build_position_encoding(num_pos_feats=64, n_pos_x=16, n_pos_y=16, is_learned=False):
    if is_learned:
        position_embedding = PositionEmbeddingLearned(n_pos_x, n_pos_y, num_pos_feats)
    else:
        position_embedding = PositionEmbeddingSine(num_pos_feats, normalize=True)
    return position_embedding

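Both encodings return a positional map shaped like the token grid, with 2 * num_pos_feats channels (one half for rows, one half for columns). A usage sketch (editor's addition; assumed import path):

# Editor's illustrative sketch (assumed import path; not part of the original file).
import torch
from core.modules.position_encoding import build_position_encoding

tokens = torch.randn(2, 256, 16, 16)   # (B, C, H, W)
pos_sine = build_position_encoding(num_pos_feats=64)(tokens)
pos_learned = build_position_encoding(num_pos_feats=64, is_learned=True)(tokens)
print(pos_sine.shape, pos_learned.shape)   # torch.Size([2, 128, 16, 16]) for both
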
NVComposer/core/modules/x_transformer.py
0 → 100755
from
functools
import
partial
from
inspect
import
isfunction
from
collections
import
namedtuple
from
einops
import
rearrange
,
repeat
import
torch
from
torch
import
nn
,
einsum
import
torch.nn.functional
as
F
DEFAULT_DIM_HEAD
=
64
Intermediates
=
namedtuple
(
"Intermediates"
,
[
"pre_softmax_attn"
,
"post_softmax_attn"
])
LayerIntermediates
=
namedtuple
(
"Intermediates"
,
[
"hiddens"
,
"attn_intermediates"
])
class
AbsolutePositionalEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
max_seq_len
):
super
().
__init__
()
self
.
emb
=
nn
.
Embedding
(
max_seq_len
,
dim
)
self
.
init_
()
def
init_
(
self
):
nn
.
init
.
normal_
(
self
.
emb
.
weight
,
std
=
0.02
)
def
forward
(
self
,
x
):
n
=
torch
.
arange
(
x
.
shape
[
1
],
device
=
x
.
device
)
return
self
.
emb
(
n
)[
None
,
:,
:]
class
FixedPositionalEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
inv_freq
=
1.0
/
(
10000
**
(
torch
.
arange
(
0
,
dim
,
2
).
float
()
/
dim
))
self
.
register_buffer
(
"inv_freq"
,
inv_freq
)
def
forward
(
self
,
x
,
seq_dim
=
1
,
offset
=
0
):
t
=
(
torch
.
arange
(
x
.
shape
[
seq_dim
],
device
=
x
.
device
).
type_as
(
self
.
inv_freq
)
+
offset
)
sinusoid_inp
=
torch
.
einsum
(
"i , j -> i j"
,
t
,
self
.
inv_freq
)
emb
=
torch
.
cat
((
sinusoid_inp
.
sin
(),
sinusoid_inp
.
cos
()),
dim
=-
1
)
return
emb
[
None
,
:,
:]
def exists(val):
    return val is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def always(val):
    def inner(*args, **kwargs):
        return val

    return inner


def not_equals(val):
    def inner(x):
        return x != val

    return inner


def equals(val):
    def inner(x):
        return x == val

    return inner


def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max


def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))


def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)


def string_begins_with(prefix, str):
    return str.startswith(prefix)


def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)


def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(
        partial(string_begins_with, prefix), d
    )
    kwargs_without_prefix = dict(
        map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))
    )
    return kwargs_without_prefix, kwargs

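These helpers let a wrapper module accept one flat kwargs dict and hand each sub-module only the settings addressed to it by prefix. For example (editor's sketch using the functions defined above):

# Editor's illustrative sketch of the prefix-grouping helpers.
kwargs = {"ff_glu": True, "attn_dropout": 0.1, "dim": 512}
ff_kwargs, rest = groupby_prefix_and_trim("ff_", kwargs)
attn_kwargs, rest = groupby_prefix_and_trim("attn_", rest)
print(ff_kwargs)    # {'glu': True}
print(attn_kwargs)  # {'dropout': 0.1}
print(rest)         # {'dim': 512}
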
class
Scale
(
nn
.
Module
):
def
__init__
(
self
,
value
,
fn
):
super
().
__init__
()
self
.
value
=
value
self
.
fn
=
fn
def
forward
(
self
,
x
,
**
kwargs
):
x
,
*
rest
=
self
.
fn
(
x
,
**
kwargs
)
return
(
x
*
self
.
value
,
*
rest
)
class
Rezero
(
nn
.
Module
):
def
__init__
(
self
,
fn
):
super
().
__init__
()
self
.
fn
=
fn
self
.
g
=
nn
.
Parameter
(
torch
.
zeros
(
1
))
def
forward
(
self
,
x
,
**
kwargs
):
x
,
*
rest
=
self
.
fn
(
x
,
**
kwargs
)
return
(
x
*
self
.
g
,
*
rest
)
class
ScaleNorm
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
eps
=
1e-5
):
super
().
__init__
()
self
.
scale
=
dim
**-
0.5
self
.
eps
=
eps
self
.
g
=
nn
.
Parameter
(
torch
.
ones
(
1
))
def
forward
(
self
,
x
):
norm
=
torch
.
norm
(
x
,
dim
=-
1
,
keepdim
=
True
)
*
self
.
scale
return
x
/
norm
.
clamp
(
min
=
self
.
eps
)
*
self
.
g
class
RMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
eps
=
1e-8
):
super
().
__init__
()
self
.
scale
=
dim
**-
0.5
self
.
eps
=
eps
self
.
g
=
nn
.
Parameter
(
torch
.
ones
(
dim
))
def
forward
(
self
,
x
):
norm
=
torch
.
norm
(
x
,
dim
=-
1
,
keepdim
=
True
)
*
self
.
scale
return
x
/
norm
.
clamp
(
min
=
self
.
eps
)
*
self
.
g
class
Residual
(
nn
.
Module
):
def
forward
(
self
,
x
,
residual
):
return
x
+
residual
class
GRUGating
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
gru
=
nn
.
GRUCell
(
dim
,
dim
)
def
forward
(
self
,
x
,
residual
):
gated_output
=
self
.
gru
(
rearrange
(
x
,
"b n d -> (b n) d"
),
rearrange
(
residual
,
"b n d -> (b n) d"
)
)
return
gated_output
.
reshape_as
(
x
)
class
GEGLU
(
nn
.
Module
):
def
__init__
(
self
,
dim_in
,
dim_out
):
super
().
__init__
()
self
.
proj
=
nn
.
Linear
(
dim_in
,
dim_out
*
2
)
def
forward
(
self
,
x
):
x
,
gate
=
self
.
proj
(
x
).
chunk
(
2
,
dim
=-
1
)
return
x
*
F
.
gelu
(
gate
)
class
FeedForward
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
dim_out
=
None
,
mult
=
4
,
glu
=
False
,
dropout
=
0.0
):
super
().
__init__
()
inner_dim
=
int
(
dim
*
mult
)
dim_out
=
default
(
dim_out
,
dim
)
project_in
=
(
nn
.
Sequential
(
nn
.
Linear
(
dim
,
inner_dim
),
nn
.
GELU
())
if
not
glu
else
GEGLU
(
dim
,
inner_dim
)
)
self
.
net
=
nn
.
Sequential
(
project_in
,
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
inner_dim
,
dim_out
)
)
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
dim_head
=
DEFAULT_DIM_HEAD
,
heads
=
8
,
causal
=
False
,
mask
=
None
,
talking_heads
=
False
,
sparse_topk
=
None
,
use_entmax15
=
False
,
num_mem_kv
=
0
,
dropout
=
0.0
,
on_attn
=
False
,
):
super
().
__init__
()
if
use_entmax15
:
raise
NotImplementedError
(
"Check out entmax activation instead of softmax activation!"
)
self
.
scale
=
dim_head
**-
0.5
self
.
heads
=
heads
self
.
causal
=
causal
self
.
mask
=
mask
inner_dim
=
dim_head
*
heads
self
.
to_q
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
to_k
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
to_v
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
talking_heads
=
talking_heads
if
talking_heads
:
self
.
pre_softmax_proj
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
heads
))
self
.
post_softmax_proj
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
heads
))
self
.
sparse_topk
=
sparse_topk
self
.
attn_fn
=
F
.
softmax
self
.
num_mem_kv
=
num_mem_kv
if
num_mem_kv
>
0
:
self
.
mem_k
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
num_mem_kv
,
dim_head
))
self
.
mem_v
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
num_mem_kv
,
dim_head
))
self
.
attn_on_attn
=
on_attn
self
.
to_out
=
(
nn
.
Sequential
(
nn
.
Linear
(
inner_dim
,
dim
*
2
),
nn
.
GLU
())
if
on_attn
else
nn
.
Linear
(
inner_dim
,
dim
)
)
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
,
context_mask
=
None
,
rel_pos
=
None
,
sinusoidal_emb
=
None
,
prev_attn
=
None
,
mem
=
None
,
):
b
,
n
,
_
,
h
,
talking_heads
,
device
=
(
*
x
.
shape
,
self
.
heads
,
self
.
talking_heads
,
x
.
device
,
)
kv_input
=
default
(
context
,
x
)
q_input
=
x
k_input
=
kv_input
v_input
=
kv_input
if
exists
(
mem
):
k_input
=
torch
.
cat
((
mem
,
k_input
),
dim
=-
2
)
v_input
=
torch
.
cat
((
mem
,
v_input
),
dim
=-
2
)
if
exists
(
sinusoidal_emb
):
offset
=
k_input
.
shape
[
-
2
]
-
q_input
.
shape
[
-
2
]
q_input
=
q_input
+
sinusoidal_emb
(
q_input
,
offset
=
offset
)
k_input
=
k_input
+
sinusoidal_emb
(
k_input
)
q
=
self
.
to_q
(
q_input
)
k
=
self
.
to_k
(
k_input
)
v
=
self
.
to_v
(
v_input
)
q
,
k
,
v
=
map
(
lambda
t
:
rearrange
(
t
,
"b n (h d) -> b h n d"
,
h
=
h
),
(
q
,
k
,
v
))
input_mask
=
None
if
any
(
map
(
exists
,
(
mask
,
context_mask
))):
q_mask
=
default
(
mask
,
lambda
:
torch
.
ones
((
b
,
n
),
device
=
device
).
bool
())
k_mask
=
q_mask
if
not
exists
(
context
)
else
context_mask
k_mask
=
default
(
k_mask
,
lambda
:
torch
.
ones
((
b
,
k
.
shape
[
-
2
]),
device
=
device
).
bool
()
)
q_mask
=
rearrange
(
q_mask
,
"b i -> b () i ()"
)
k_mask
=
rearrange
(
k_mask
,
"b j -> b () () j"
)
input_mask
=
q_mask
*
k_mask
if
self
.
num_mem_kv
>
0
:
mem_k
,
mem_v
=
map
(
lambda
t
:
repeat
(
t
,
"h n d -> b h n d"
,
b
=
b
),
(
self
.
mem_k
,
self
.
mem_v
)
)
k
=
torch
.
cat
((
mem_k
,
k
),
dim
=-
2
)
v
=
torch
.
cat
((
mem_v
,
v
),
dim
=-
2
)
if
exists
(
input_mask
):
input_mask
=
F
.
pad
(
input_mask
,
(
self
.
num_mem_kv
,
0
),
value
=
True
)
dots
=
einsum
(
"b h i d, b h j d -> b h i j"
,
q
,
k
)
*
self
.
scale
mask_value
=
max_neg_value
(
dots
)
if
exists
(
prev_attn
):
dots
=
dots
+
prev_attn
pre_softmax_attn
=
dots
if
talking_heads
:
dots
=
einsum
(
"b h i j, h k -> b k i j"
,
dots
,
self
.
pre_softmax_proj
).
contiguous
()
if
exists
(
rel_pos
):
dots
=
rel_pos
(
dots
)
if
exists
(
input_mask
):
dots
.
masked_fill_
(
~
input_mask
,
mask_value
)
del
input_mask
if
self
.
causal
:
i
,
j
=
dots
.
shape
[
-
2
:]
r
=
torch
.
arange
(
i
,
device
=
device
)
mask
=
rearrange
(
r
,
"i -> () () i ()"
)
<
rearrange
(
r
,
"j -> () () () j"
)
mask
=
F
.
pad
(
mask
,
(
j
-
i
,
0
),
value
=
False
)
dots
.
masked_fill_
(
mask
,
mask_value
)
del
mask
if
exists
(
self
.
sparse_topk
)
and
self
.
sparse_topk
<
dots
.
shape
[
-
1
]:
top
,
_
=
dots
.
topk
(
self
.
sparse_topk
,
dim
=-
1
)
vk
=
top
[...,
-
1
].
unsqueeze
(
-
1
).
expand_as
(
dots
)
mask
=
dots
<
vk
dots
.
masked_fill_
(
mask
,
mask_value
)
del
mask
attn
=
self
.
attn_fn
(
dots
,
dim
=-
1
)
post_softmax_attn
=
attn
attn
=
self
.
dropout
(
attn
)
if
talking_heads
:
attn
=
einsum
(
"b h i j, h k -> b k i j"
,
attn
,
self
.
post_softmax_proj
).
contiguous
()
out
=
einsum
(
"b h i j, b h j d -> b h i d"
,
attn
,
v
)
out
=
rearrange
(
out
,
"b h n d -> b n (h d)"
)
intermediates
=
Intermediates
(
pre_softmax_attn
=
pre_softmax_attn
,
post_softmax_attn
=
post_softmax_attn
)
return
self
.
to_out
(
out
),
intermediates
class
AttentionLayers
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
depth
,
heads
=
8
,
causal
=
False
,
cross_attend
=
False
,
only_cross
=
False
,
use_scalenorm
=
False
,
use_rmsnorm
=
False
,
use_rezero
=
False
,
rel_pos_num_buckets
=
32
,
rel_pos_max_distance
=
128
,
position_infused_attn
=
False
,
custom_layers
=
None
,
sandwich_coef
=
None
,
par_ratio
=
None
,
residual_attn
=
False
,
cross_residual_attn
=
False
,
macaron
=
False
,
pre_norm
=
True
,
gate_residual
=
False
,
**
kwargs
,
):
super
().
__init__
()
ff_kwargs
,
kwargs
=
groupby_prefix_and_trim
(
"ff_"
,
kwargs
)
attn_kwargs
,
_
=
groupby_prefix_and_trim
(
"attn_"
,
kwargs
)
dim_head
=
attn_kwargs
.
get
(
"dim_head"
,
DEFAULT_DIM_HEAD
)
self
.
dim
=
dim
self
.
depth
=
depth
self
.
layers
=
nn
.
ModuleList
([])
self
.
has_pos_emb
=
position_infused_attn
self
.
pia_pos_emb
=
(
FixedPositionalEmbedding
(
dim
)
if
position_infused_attn
else
None
)
self
.
rotary_pos_emb
=
always
(
None
)
assert
(
rel_pos_num_buckets
<=
rel_pos_max_distance
),
"number of relative position buckets must be less than the relative position max distance"
self
.
rel_pos
=
None
self
.
pre_norm
=
pre_norm
self
.
residual_attn
=
residual_attn
self
.
cross_residual_attn
=
cross_residual_attn
norm_class
=
ScaleNorm
if
use_scalenorm
else
nn
.
LayerNorm
norm_class
=
RMSNorm
if
use_rmsnorm
else
norm_class
norm_fn
=
partial
(
norm_class
,
dim
)
norm_fn
=
nn
.
Identity
if
use_rezero
else
norm_fn
branch_fn
=
Rezero
if
use_rezero
else
None
if
cross_attend
and
not
only_cross
:
default_block
=
(
"a"
,
"c"
,
"f"
)
elif
cross_attend
and
only_cross
:
default_block
=
(
"c"
,
"f"
)
else
:
default_block
=
(
"a"
,
"f"
)
if
macaron
:
default_block
=
(
"f"
,)
+
default_block
if
exists
(
custom_layers
):
layer_types
=
custom_layers
elif
exists
(
par_ratio
):
par_depth
=
depth
*
len
(
default_block
)
assert
1
<
par_ratio
<=
par_depth
,
"par ratio out of range"
default_block
=
tuple
(
filter
(
not_equals
(
"f"
),
default_block
))
par_attn
=
par_depth
//
par_ratio
depth_cut
=
par_depth
*
2
//
3
par_width
=
(
depth_cut
+
depth_cut
//
par_attn
)
//
par_attn
assert
(
len
(
default_block
)
<=
par_width
),
"default block is too large for par_ratio"
par_block
=
default_block
+
(
"f"
,)
*
(
par_width
-
len
(
default_block
))
par_head
=
par_block
*
par_attn
layer_types
=
par_head
+
(
"f"
,)
*
(
par_depth
-
len
(
par_head
))
elif
exists
(
sandwich_coef
):
assert
(
sandwich_coef
>
0
and
sandwich_coef
<=
depth
),
"sandwich coefficient should be less than the depth"
layer_types
=
(
(
"a"
,)
*
sandwich_coef
+
default_block
*
(
depth
-
sandwich_coef
)
+
(
"f"
,)
*
sandwich_coef
)
else
:
layer_types
=
default_block
*
depth
self
.
layer_types
=
layer_types
self
.
num_attn_layers
=
len
(
list
(
filter
(
equals
(
"a"
),
layer_types
)))
for
layer_type
in
self
.
layer_types
:
if
layer_type
==
"a"
:
layer
=
Attention
(
dim
,
heads
=
heads
,
causal
=
causal
,
**
attn_kwargs
)
elif
layer_type
==
"c"
:
layer
=
Attention
(
dim
,
heads
=
heads
,
**
attn_kwargs
)
elif
layer_type
==
"f"
:
layer
=
FeedForward
(
dim
,
**
ff_kwargs
)
layer
=
layer
if
not
macaron
else
Scale
(
0.5
,
layer
)
else
:
raise
Exception
(
f
"invalid layer type
{
layer_type
}
"
)
if
isinstance
(
layer
,
Attention
)
and
exists
(
branch_fn
):
layer
=
branch_fn
(
layer
)
if
gate_residual
:
residual_fn
=
GRUGating
(
dim
)
else
:
residual_fn
=
Residual
()
self
.
layers
.
append
(
nn
.
ModuleList
([
norm_fn
(),
layer
,
residual_fn
]))
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
,
context_mask
=
None
,
mems
=
None
,
return_hiddens
=
False
,
):
hiddens
=
[]
intermediates
=
[]
prev_attn
=
None
prev_cross_attn
=
None
mems
=
mems
.
copy
()
if
exists
(
mems
)
else
[
None
]
*
self
.
num_attn_layers
for
ind
,
(
layer_type
,
(
norm
,
block
,
residual_fn
))
in
enumerate
(
zip
(
self
.
layer_types
,
self
.
layers
)
):
is_last
=
ind
==
(
len
(
self
.
layers
)
-
1
)
if
layer_type
==
"a"
:
hiddens
.
append
(
x
)
layer_mem
=
mems
.
pop
(
0
)
residual
=
x
if
self
.
pre_norm
:
x
=
norm
(
x
)
if
layer_type
==
"a"
:
out
,
inter
=
block
(
x
,
mask
=
mask
,
sinusoidal_emb
=
self
.
pia_pos_emb
,
rel_pos
=
self
.
rel_pos
,
prev_attn
=
prev_attn
,
mem
=
layer_mem
,
)
elif
layer_type
==
"c"
:
out
,
inter
=
block
(
x
,
context
=
context
,
mask
=
mask
,
context_mask
=
context_mask
,
prev_attn
=
prev_cross_attn
,
)
elif
layer_type
==
"f"
:
out
=
block
(
x
)
x
=
residual_fn
(
out
,
residual
)
if
layer_type
in
(
"a"
,
"c"
):
intermediates
.
append
(
inter
)
if
layer_type
==
"a"
and
self
.
residual_attn
:
prev_attn
=
inter
.
pre_softmax_attn
elif
layer_type
==
"c"
and
self
.
cross_residual_attn
:
prev_cross_attn
=
inter
.
pre_softmax_attn
if
not
self
.
pre_norm
and
not
is_last
:
x
=
norm
(
x
)
if
return_hiddens
:
intermediates
=
LayerIntermediates
(
hiddens
=
hiddens
,
attn_intermediates
=
intermediates
)
return
x
,
intermediates
return
x
class
Encoder
(
AttentionLayers
):
def
__init__
(
self
,
**
kwargs
):
assert
"causal"
not
in
kwargs
,
"cannot set causality on encoder"
super
().
__init__
(
causal
=
False
,
**
kwargs
)
class
TransformerWrapper
(
nn
.
Module
):
def
__init__
(
self
,
*
,
num_tokens
,
max_seq_len
,
attn_layers
,
emb_dim
=
None
,
max_mem_len
=
0.0
,
emb_dropout
=
0.0
,
num_memory_tokens
=
None
,
tie_embedding
=
False
,
use_pos_emb
=
True
,
):
super
().
__init__
()
assert
isinstance
(
attn_layers
,
AttentionLayers
),
"attention layers must be one of Encoder or Decoder"
dim
=
attn_layers
.
dim
emb_dim
=
default
(
emb_dim
,
dim
)
self
.
max_seq_len
=
max_seq_len
self
.
max_mem_len
=
max_mem_len
self
.
num_tokens
=
num_tokens
self
.
token_emb
=
nn
.
Embedding
(
num_tokens
,
emb_dim
)
self
.
pos_emb
=
(
AbsolutePositionalEmbedding
(
emb_dim
,
max_seq_len
)
if
(
use_pos_emb
and
not
attn_layers
.
has_pos_emb
)
else
always
(
0
)
)
self
.
emb_dropout
=
nn
.
Dropout
(
emb_dropout
)
self
.
project_emb
=
nn
.
Linear
(
emb_dim
,
dim
)
if
emb_dim
!=
dim
else
nn
.
Identity
()
self
.
attn_layers
=
attn_layers
self
.
norm
=
nn
.
LayerNorm
(
dim
)
self
.
init_
()
self
.
to_logits
=
(
nn
.
Linear
(
dim
,
num_tokens
)
if
not
tie_embedding
else
lambda
t
:
t
@
self
.
token_emb
.
weight
.
t
()
)
num_memory_tokens
=
default
(
num_memory_tokens
,
0
)
self
.
num_memory_tokens
=
num_memory_tokens
if
num_memory_tokens
>
0
:
self
.
memory_tokens
=
nn
.
Parameter
(
torch
.
randn
(
num_memory_tokens
,
dim
))
if
hasattr
(
attn_layers
,
"num_memory_tokens"
):
attn_layers
.
num_memory_tokens
=
num_memory_tokens
def
init_
(
self
):
nn
.
init
.
normal_
(
self
.
token_emb
.
weight
,
std
=
0.02
)
def
forward
(
self
,
x
,
return_embeddings
=
False
,
mask
=
None
,
return_mems
=
False
,
return_attn
=
False
,
mems
=
None
,
**
kwargs
,
):
b
,
n
,
device
,
num_mem
=
*
x
.
shape
,
x
.
device
,
self
.
num_memory_tokens
x
=
self
.
token_emb
(
x
)
x
+=
self
.
pos_emb
(
x
)
x
=
self
.
emb_dropout
(
x
)
x
=
self
.
project_emb
(
x
)
if
num_mem
>
0
:
mem
=
repeat
(
self
.
memory_tokens
,
"n d -> b n d"
,
b
=
b
)
x
=
torch
.
cat
((
mem
,
x
),
dim
=
1
)
# auto-handle masking after appending memory tokens
if
exists
(
mask
):
mask
=
F
.
pad
(
mask
,
(
num_mem
,
0
),
value
=
True
)
x
,
intermediates
=
self
.
attn_layers
(
x
,
mask
=
mask
,
mems
=
mems
,
return_hiddens
=
True
,
**
kwargs
)
x
=
self
.
norm
(
x
)
mem
,
x
=
x
[:,
:
num_mem
],
x
[:,
num_mem
:]
out
=
self
.
to_logits
(
x
)
if
not
return_embeddings
else
x
if
return_mems
:
hiddens
=
intermediates
.
hiddens
new_mems
=
(
list
(
map
(
lambda
pair
:
torch
.
cat
(
pair
,
dim
=-
2
),
zip
(
mems
,
hiddens
)))
if
exists
(
mems
)
else
hiddens
)
new_mems
=
list
(
map
(
lambda
t
:
t
[...,
-
self
.
max_mem_len
:,
:].
detach
(),
new_mems
)
)
return
out
,
new_mems
if
return_attn
:
attn_maps
=
list
(
map
(
lambda
t
:
t
.
post_softmax_attn
,
intermediates
.
attn_intermediates
)
)
return
out
,
attn_maps
return
out
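For orientation, here is a minimal usage sketch of the wrapper defined above; it is not part of the commit. It assumes the file is importable as core.modules.x_transformer and uses small, illustrative hyperparameters.

# Minimal sketch (not part of the commit): embed a token sequence with the
# Encoder/TransformerWrapper pair defined above. Import path and sizes are assumptions.
import torch
from core.modules.x_transformer import Encoder, TransformerWrapper  # assumed import path

model = TransformerWrapper(
    num_tokens=1000,                                  # illustrative vocabulary size
    max_seq_len=77,                                   # maximum sequence length
    attn_layers=Encoder(dim=512, depth=4, heads=8),   # non-causal attention stack
)
tokens = torch.randint(0, 1000, (2, 77))              # [batch, seq_len]
embeddings = model(tokens, return_embeddings=True)    # [2, 77, 512] token features
print(embeddings.shape)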
NVComposer/main/evaluation/funcs.py
0 → 100755
View file @
30af93f2
from core.models.samplers.ddim import DDIMSampler
import glob
import json
import os
import sys
from collections import OrderedDict

import numpy as np
import torch
import torchvision
from PIL import Image

sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))


def batch_ddim_sampling(
    model,
    cond,
    noise_shape,
    n_samples=1,
    ddim_steps=50,
    ddim_eta=1.0,
    cfg_scale=1.0,
    temporal_cfg_scale=None,
    use_cat_ucg=False,
    **kwargs,
):
    ddim_sampler = DDIMSampler(model)
    uncond_type = model.uncond_type
    batch_size = noise_shape[0]

    # construct unconditional guidance
    if cfg_scale != 1.0:
        if uncond_type == "empty_seq":
            prompts = batch_size * [""]
            # prompts = N * T * [""]  # if is_image_batch=True
            uc_emb = model.get_learned_conditioning(prompts)
        elif uncond_type == "zero_embed":
            c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
            uc_emb = torch.zeros_like(c_emb)

        # process image condition
        if hasattr(model, "embedder"):
            uc_img = torch.zeros(noise_shape[0], 3, 224, 224).to(model.device)
            # img: b c h w >> b l c
            uc_img = model.get_image_embeds(uc_img)
            uc_emb = torch.cat([uc_emb, uc_img], dim=1)

        if isinstance(cond, dict):
            uc = {key: cond[key] for key in cond.keys()}
            uc.update({"c_crossattn": [uc_emb]})
            # special CFG for frame concatenation
            if use_cat_ucg and hasattr(model, "cond_concat") and model.cond_concat:
                uc_cat = torch.zeros(
                    noise_shape[0], model.cond_channels, *noise_shape[2:]
                ).to(model.device)
                uc.update({"c_concat": [uc_cat]})
        else:
            uc = [uc_emb]
    else:
        uc = None
    # uc.update({'fps': torch.tensor([-4]*batch_size).to(model.device).long()})

    # sampling
    noise = torch.randn(noise_shape, device=model.device)
    # x_T = repeat(noise[:,:,:1,:,:], 'b c l h w -> b c (l t) h w', t=noise_shape[2])
    # x_T = 0.2 * x_T + 0.8 * torch.randn(noise_shape, device=model.device)
    x_T = None
    batch_variants = []
    # batch_variants1, batch_variants2 = [], []
    for _ in range(n_samples):
        if ddim_sampler is not None:
            samples, _ = ddim_sampler.sample(
                S=ddim_steps,
                conditioning=cond,
                batch_size=noise_shape[0],
                shape=noise_shape[1:],
                verbose=False,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
                eta=ddim_eta,
                temporal_length=noise_shape[2],
                conditional_guidance_scale_temporal=temporal_cfg_scale,
                x_T=x_T,
                **kwargs,
            )
        # reconstruct from latent to pixel space
        batch_images = model.decode_first_stage(samples)
        batch_variants.append(batch_images)
        """
        pred_x0_list, x_iter_list = _['pred_x0'], _['x_inter']
        steps = [0, 15, 25, 30, 35, 40, 43, 46, 49, 50]
        for nn in steps:
            pred_x0 = pred_x0_list[nn]
            x_iter = x_iter_list[nn]
            batch_images_x0 = model.decode_first_stage(pred_x0)
            batch_variants1.append(batch_images_x0)
            batch_images_xt = model.decode_first_stage(x_iter)
            batch_variants2.append(batch_images_xt)
        """
    # batch, <samples>, c, t, h, w
    batch_variants = torch.stack(batch_variants, dim=1)
    # batch_variants1 = torch.stack(batch_variants1, dim=1)
    # batch_variants2 = torch.stack(batch_variants2, dim=1)
    # return batch_variants1, batch_variants2
    return batch_variants


def batch_sliding_interpolation(
    model,
    cond,
    base_videos,
    base_stride,
    noise_shape,
    n_samples=1,
    ddim_steps=50,
    ddim_eta=1.0,
    cfg_scale=1.0,
    temporal_cfg_scale=None,
    **kwargs,
):
    """
    Current implementation has a flaw: the inter-episode keyframe is used as pre-last and cur-first, so the keyframe is repeated.
    For example, cond_frames=[0,4,7], model.temporal_length=8, base_stride=4, then
        base frame   : 0   4   8   12  16  20  24  28
        interpolation: (0~7)   (8~15)  (16~23) (20~27)
    """
    b, c, t, h, w = noise_shape
    base_z0 = model.encode_first_stage(base_videos)
    unit_length = model.temporal_length
    n_base_frames = base_videos.shape[2]
    n_refs = len(model.cond_frames)
    sliding_steps = (n_base_frames - 1) // (n_refs - 1)
    sliding_steps = (
        sliding_steps + 1 if (n_base_frames - 1) % (n_refs - 1) > 0 else sliding_steps
    )

    cond_mask = model.cond_mask.to("cuda")
    proxy_z0 = torch.zeros((b, c, unit_length, h, w), dtype=torch.float32).to("cuda")

    batch_samples = None
    last_offset = None
    for idx in range(sliding_steps):
        base_idx = idx * (n_refs - 1)
        # check index overflow
        if base_idx + n_refs > n_base_frames:
            last_offset = base_idx - (n_base_frames - n_refs)
            base_idx = n_base_frames - n_refs
        cond_z0 = base_z0[:, :, base_idx : base_idx + n_refs, :, :]
        proxy_z0[:, :, model.cond_frames, :, :] = cond_z0

        if "c_concat" in cond:
            c_cat, text_emb = cond["c_concat"][0], cond["c_crossattn"][0]
            episode_idx = idx * unit_length
            if last_offset is not None:
                episode_idx = episode_idx - last_offset * base_stride
            cond_idx = {
                "c_concat": [
                    c_cat[:, :, episode_idx : episode_idx + unit_length, :, :]
                ],
                "c_crossattn": [text_emb],
            }
        else:
            cond_idx = cond

        noise_shape_idx = [b, c, unit_length, h, w]
        # batch, <samples>, c, t, h, w
        batch_idx = batch_ddim_sampling(
            model,
            cond_idx,
            noise_shape_idx,
            n_samples,
            ddim_steps,
            ddim_eta,
            cfg_scale,
            temporal_cfg_scale,
            mask=cond_mask,
            x0=proxy_z0,
            **kwargs,
        )

        if batch_samples is None:
            batch_samples = batch_idx
        else:
            # b,s,c,t,h,w
            if last_offset is None:
                batch_samples = torch.cat(
                    [batch_samples[:, :, :, :-1, :, :], batch_idx], dim=3
                )
            else:
                batch_samples = torch.cat(
                    [
                        batch_samples[:, :, :, :-1, :, :],
                        batch_idx[:, :, :, last_offset * base_stride :, :, :],
                    ],
                    dim=3,
                )

    return batch_samples


def get_filelist(data_dir, ext="*"):
    file_list = glob.glob(os.path.join(data_dir, "*.%s" % ext))
    file_list.sort()
    return file_list


def get_dirlist(path):
    list = []
    if os.path.exists(path):
        files = os.listdir(path)
        for file in files:
            m = os.path.join(path, file)
            if os.path.isdir(m):
                list.append(m)
    list.sort()
    return list


def load_model_checkpoint(model, ckpt, adapter_ckpt=None):
    def load_checkpoint(model, ckpt, full_strict):
        state_dict = torch.load(ckpt, map_location="cpu", weights_only=True)
        try:
            # deepspeed
            new_pl_sd = OrderedDict()
            for key in state_dict["module"].keys():
                new_pl_sd[key[16:]] = state_dict["module"][key]
            model.load_state_dict(new_pl_sd, strict=full_strict)
        except:
            if "state_dict" in list(state_dict.keys()):
                state_dict = state_dict["state_dict"]
            model.load_state_dict(state_dict, strict=full_strict)
        return model

    if adapter_ckpt:
        # main model
        load_checkpoint(model, ckpt, full_strict=False)
        print(">>> model checkpoint loaded.")
        # adapter
        state_dict = torch.load(adapter_ckpt, map_location="cpu")
        if "state_dict" in list(state_dict.keys()):
            state_dict = state_dict["state_dict"]
        model.adapter.load_state_dict(state_dict, strict=True)
        print(">>> adapter checkpoint loaded.")
    else:
        load_checkpoint(model, ckpt, full_strict=False)
        print(">>> model checkpoint loaded.")
    return model


def load_prompts(prompt_file):
    f = open(prompt_file, "r")
    prompt_list = []
    for idx, line in enumerate(f.readlines()):
        l = line.strip()
        if len(l) != 0:
            prompt_list.append(l)
    f.close()
    return prompt_list


def load_camera_poses(filepath_list, video_frames=16):
    pose_list = []
    for filepath in filepath_list:
        with open(filepath, "r") as f:
            pose = json.load(f)
        pose = np.array(pose)  # [t, 12]
        pose = torch.tensor(pose).float()  # [t, 12]
        assert (
            pose.shape[0] == video_frames
        ), f"conditional pose frames not matching the target frames [{video_frames}]."
        pose_list.append(pose)
    batch_poses = torch.stack(pose_list, dim=0)
    # shape [b,t,12,1]
    return batch_poses[..., None]


def save_videos(
    batch_tensors: torch.Tensor, save_dir: str, filenames: list[str], fps: int = 10
):
    # b,samples,t,c,h,w
    n_samples = batch_tensors.shape[1]
    for idx, vid_tensor in enumerate(batch_tensors):
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), -1.0, 1.0)
        video = video.permute(1, 0, 2, 3, 4)  # t,n,c,h,w
        frame_grids = [
            torchvision.utils.make_grid(framesheet, nrow=int(n_samples))
            for framesheet in video
        ]  # [3, 1*h, n*w]
        # stack in temporal dim [t, 3, n*h, w]
        grid = torch.stack(frame_grids, dim=0)
        grid = (grid + 1.0) / 2.0
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
        savepath = os.path.join(save_dir, f"{filenames[idx]}.mp4")
        torchvision.io.write_video(
            savepath, grid, fps=fps, video_codec="h264", options={"crf": "10"}
        )
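To make the expected tensor layout of save_videos() concrete, here is a small sketch on random data; it is not part of the commit, the import path is assumed, and writing the file requires a torchvision video backend (PyAV/ffmpeg) with h264 support.

# Illustrative only (not in the commit): save_videos expects [batch, samples, t, c, h, w]
# tensors with values in [-1, 1]; filenames are given without the .mp4 extension.
import os
import torch
from main.evaluation.funcs import save_videos  # assumed import path

fake_batch = torch.rand(1, 2, 16, 3, 64, 64) * 2.0 - 1.0  # b=1, samples=2, t=16, c=3, 64x64
os.makedirs("./tmp_videos", exist_ok=True)
save_videos(fake_batch, "./tmp_videos", filenames=["demo"], fps=10)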
NVComposer/main/evaluation/pose_interpolation.py
0 → 100755
View file @
30af93f2
import torch
import math


def slerp(R1, R2, alpha):
    """
    Perform Spherical Linear Interpolation (SLERP) between two rotation matrices.
    R1, R2: (3x3) rotation matrices.
    alpha: interpolation factor, ranging from 0 to 1.
    """

    # Convert the rotation matrices to quaternions
    def rotation_matrix_to_quaternion(R):
        w = torch.sqrt(1.0 + R[0, 0] + R[1, 1] + R[2, 2]) / 2.0
        w4 = 4.0 * w
        x = (R[2, 1] - R[1, 2]) / w4
        y = (R[0, 2] - R[2, 0]) / w4
        z = (R[1, 0] - R[0, 1]) / w4
        return torch.tensor([w, x, y, z]).float()

    def quaternion_to_rotation_matrix(q):
        w, x, y, z = q
        return torch.tensor(
            [
                [
                    1 - 2 * y * y - 2 * z * z,
                    2 * x * y - 2 * w * z,
                    2 * x * z + 2 * w * y,
                ],
                [
                    2 * x * y + 2 * w * z,
                    1 - 2 * x * x - 2 * z * z,
                    2 * y * z - 2 * w * x,
                ],
                [
                    2 * x * z - 2 * w * y,
                    2 * y * z + 2 * w * x,
                    1 - 2 * x * x - 2 * y * y,
                ],
            ]
        ).float()

    q1 = rotation_matrix_to_quaternion(R1)
    q2 = rotation_matrix_to_quaternion(R2)

    # Dot product of the quaternions
    dot = torch.dot(q1, q2)

    # If the dot product is negative, negate one quaternion to ensure the shortest path is taken
    if dot < 0.0:
        q2 = -q2
        dot = -dot

    # SLERP formula
    if dot > 0.9995:
        # If the quaternions are nearly identical, use linear interpolation
        q_interp = (1 - alpha) * q1 + alpha * q2
    else:
        theta_0 = torch.acos(dot)  # Angle between q1 and q2
        sin_theta_0 = torch.sin(theta_0)
        theta = theta_0 * alpha  # Angle between q1 and interpolated quaternion
        sin_theta = torch.sin(theta)
        s1 = torch.sin((1 - alpha) * theta_0) / sin_theta_0
        s2 = sin_theta / sin_theta_0
        q_interp = s1 * q1 + s2 * q2

    # Convert the interpolated quaternion back to a rotation matrix
    R_interp = quaternion_to_rotation_matrix(q_interp)
    return R_interp


def interpolate_camera_poses(pose1, pose2, num_steps):
    """
    Interpolate between two camera poses (3x4 matrices) over a number of steps.
    pose1, pose2: (3x4) camera pose matrices (R|t), where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
    num_steps: number of interpolation steps.

    Returns:
        A list of interpolated poses as (3x4) matrices.
    """
    R1, t1 = pose1[:, :3], pose1[:, 3]
    R2, t2 = pose2[:, :3], pose2[:, 3]

    interpolated_poses = []
    for i in range(num_steps):
        alpha = i / (num_steps - 1)  # Interpolation factor ranging from 0 to 1

        # Interpolate rotation using SLERP
        R_interp = slerp(R1, R2, alpha)

        # Interpolate translation using linear interpolation (LERP)
        t_interp = (1 - alpha) * t1 + alpha * t2

        # Combine interpolated rotation and translation into a (3x4) pose matrix
        pose_interp = torch.cat([R_interp, t_interp.unsqueeze(1)], dim=1)
        interpolated_poses.append(pose_interp)

    return interpolated_poses


def rotation_matrix_from_xyz_angles(x_angle, y_angle, z_angle):
    """
    Compute the rotation matrix from given x, y, z angles (in radians).
    x_angle: Rotation around the x-axis (pitch).
    y_angle: Rotation around the y-axis (yaw).
    z_angle: Rotation around the z-axis (roll).

    Returns:
        A 3x3 rotation matrix.
    """
    # Rotation matrices around each axis
    Rx = torch.tensor(
        [
            [1, 0, 0],
            [0, torch.cos(x_angle), -torch.sin(x_angle)],
            [0, torch.sin(x_angle), torch.cos(x_angle)],
        ]
    ).float()

    Ry = torch.tensor(
        [
            [torch.cos(y_angle), 0, torch.sin(y_angle)],
            [0, 1, 0],
            [-torch.sin(y_angle), 0, torch.cos(y_angle)],
        ]
    ).float()

    Rz = torch.tensor(
        [
            [torch.cos(z_angle), -torch.sin(z_angle), 0],
            [torch.sin(z_angle), torch.cos(z_angle), 0],
            [0, 0, 1],
        ]
    ).float()

    # Combined rotation matrix R = Rz * Ry * Rx
    R_combined = Rz @ Ry @ Rx
    return R_combined.float()


def move_pose(pose1, x_angle, y_angle, z_angle, translation):
    """
    Calculate the second camera pose based on the first pose and given rotations (x, y, z) and translation.
    pose1: The first camera pose (3x4 matrix).
    x_angle, y_angle, z_angle: Rotation angles around the x, y, and z axes, in radians.
    translation: Translation vector (3,).

    Returns:
        pose2: The second camera pose as a (3x4) matrix.
    """
    # Extract the rotation (R1) and translation (t1) from the first pose
    R1 = pose1[:, :3]
    t1 = pose1[:, 3]

    # Calculate the new rotation matrix from the given angles
    R_delta = rotation_matrix_from_xyz_angles(x_angle, y_angle, z_angle)

    # New rotation = R1 * R_delta
    R2 = R1 @ R_delta

    # New translation = t1 + translation
    t2 = t1 + translation

    # Combine R2 and t2 into the new pose (3x4 matrix)
    pose2 = torch.cat([R2, t2.unsqueeze(1)], dim=1)
    return pose2


def deg2rad(degrees):
    """Convert degrees to radians."""
    return degrees * math.pi / 180.0


def generate_spherical_trajectory(end_angles, radius=1.0, num_steps=36):
    """
    Generate a camera-to-world (C2W) trajectory interpolating angles on a sphere.

    Args:
        end_angles (tuple): The endpoint rotation angles in degrees (x, y, z).
                            (start is assumed to be (0, 0, 0)).
        radius (float): Radius of the sphere.
        num_steps (int): Number of steps in the trajectory.

    Returns:
        torch.Tensor: A tensor of shape [num_steps, 3, 4] with the C2W transformations.
    """
    # Convert angles to radians
    end_angles_rad = torch.tensor(
        [deg2rad(angle) for angle in end_angles], dtype=torch.float32
    )

    # Interpolate angles linearly
    interpolated_angles = (
        torch.linspace(0, 1, num_steps).view(-1, 1) * end_angles_rad
    )  # Shape: [num_steps, 3]

    poses = []
    for angles in interpolated_angles:
        # Extract interpolated angles (only x/pitch and y/yaw affect the position on the sphere)
        x_angle, y_angle = angles[0], angles[1]

        # Compute camera position on the sphere
        x = radius * math.sin(y_angle) * math.cos(x_angle)
        y = radius * math.sin(x_angle)
        z = radius * math.cos(y_angle) * math.cos(x_angle)
        cam_position = torch.tensor([x, y, z], dtype=torch.float32)

        # Camera's forward direction (looking at the origin)
        look_at_dir = -cam_position / torch.norm(cam_position)

        # Define the "up" vector
        up = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32)

        # Compute the right vector
        right = torch.cross(up, look_at_dir)
        right = right / torch.norm(right)

        # Recompute the orthogonal up vector
        up = torch.cross(look_at_dir, right)

        # Build the rotation matrix
        rotation_matrix = torch.stack([right, up, look_at_dir], dim=0)  # [3, 3]

        # Combine the rotation matrix with the translation (camera position)
        c2w = torch.cat([rotation_matrix, cam_position.view(3, 1)], dim=1)  # [3, 4]

        # Append the pose
        poses.append(c2w)

    return poses
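A small, self-contained sketch (not part of the commit) shows how these helpers compose: start from an identity pose, displace it with move_pose(), then blend between the two with interpolate_camera_poses(). The import path is an assumption; angles are passed as 0-dim tensors because the rotation helpers call torch.cos/torch.sin on them.

# Hedged usage sketch (not in the commit): SLERP/LERP a 16-step trajectory between two poses.
import math
import torch
from main.evaluation.pose_interpolation import interpolate_camera_poses, move_pose  # assumed path

pose_start = torch.cat([torch.eye(3), torch.zeros(3, 1)], dim=1)  # identity [3, 4] pose
pose_end = move_pose(
    pose_start,
    x_angle=torch.tensor(0.0),
    y_angle=torch.tensor(math.radians(30.0)),       # 30 degree yaw
    z_angle=torch.tensor(0.0),
    translation=torch.tensor([0.1, 0.0, 0.0]),      # small sideways shift
)
trajectory = interpolate_camera_poses(pose_start, pose_end, num_steps=16)
print(len(trajectory), trajectory[0].shape)          # 16 poses, each [3, 4]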
NVComposer/main/evaluation/utils_eval.py
0 → 100644
View file @
30af93f2
import torch


def process_inference_batch(cfg_scale, batch, model, with_uncondition_extra=False):
    for k in batch.keys():
        if isinstance(batch[k], torch.Tensor):
            batch[k] = batch[k].to(model.device, dtype=model.dtype)
    z, cond, x_rec = model.get_batch_input(
        batch,
        random_drop_training_conditions=False,
        return_reconstructed_target_images=True,
    )
    # batch_size = x_rec.shape[0]
    # Get unconditioned embedding for classifier-free guidance sampling
    if cfg_scale != 1.0:
        uc = model.get_unconditional_dict_for_sampling(batch, cond, x_rec)
    else:
        uc = None
    if with_uncondition_extra:
        uc_extra = model.get_unconditional_dict_for_sampling(
            batch, cond, x_rec, is_extra=True
        )
        return cond, uc, uc_extra, x_rec
    else:
        return cond, uc, x_rec
NVComposer/main/utils_data.py
0 → 100755
View file @
30af93f2
from utils.utils import instantiate_from_config
import os
import sys
from functools import partial

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset

os.chdir(sys.path[0])
sys.path.append("..")


def t_range(name, tensor):
    print(
        f"{name}: shape={tensor.shape}, max={torch.max(tensor)}, min={torch.min(tensor)}."
    )


def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()
    worker_id = worker_info.id
    return np.random.seed(np.random.get_state()[1][0] + worker_id)


class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(
        self,
        batch_size,
        train=None,
        validation=None,
        test=None,
        predict=None,
        train_img=None,
        wrap=False,
        num_workers=None,
        shuffle_test_loader=False,
        use_worker_init_fn=False,
        shuffle_val_dataloader=False,
        test_max_n_samples=None,
        **kwargs,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(
                self._val_dataloader, shuffle=shuffle_val_dataloader
            )
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(
                self._test_dataloader, shuffle=shuffle_test_loader
            )
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader

        # train image dataset
        if train_img is not None:
            img_data = instantiate_from_config(train_img)
            self.img_loader = img_data.train_dataloader()
        else:
            self.img_loader = None
        self.wrap = wrap
        self.test_max_n_samples = test_max_n_samples
        self.collate_fn = None

    def prepare_data(self):
        # for data_cfg in self.dataset_configs.values():
        #     instantiate_from_config(data_cfg)
        pass

    def setup(self, stage=None):
        self.datasets = dict(
            (k, instantiate_from_config(self.dataset_configs[k]))
            for k in self.dataset_configs
        )
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        is_iterable_dataset = False
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        loader = DataLoader(
            self.datasets["train"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False if is_iterable_dataset else True,
            worker_init_fn=init_fn,
            collate_fn=self.collate_fn,
        )
        if self.img_loader is not None:
            return {"loader_video": loader, "loader_img": self.img_loader}
        else:
            return loader

    def _val_dataloader(self, shuffle=False):
        init_fn = None
        return DataLoader(
            self.datasets["validation"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            shuffle=shuffle,
            collate_fn=self.collate_fn,
        )

    def _test_dataloader(self, shuffle=False):
        is_iterable_dataset = False
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)
        if self.test_max_n_samples is not None:
            dataset = torch.utils.data.Subset(
                self.datasets["test"], list(range(self.test_max_n_samples))
            )
        else:
            dataset = self.datasets["test"]
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            shuffle=shuffle,
            collate_fn=self.collate_fn,
        )

    def _predict_dataloader(self, shuffle=False):
        init_fn = None
        return DataLoader(
            self.datasets["predict"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            collate_fn=self.collate_fn,
        )
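The data module above is driven entirely by instantiate_from_config-style dicts (a "target" class path plus "params"). The sketch below, which is not part of the commit, shows the expected shape of such a config; the dataset class name is hypothetical and the import path is an assumption.

# Hedged sketch (not in the commit): wiring DataModuleFromConfig to a target/params config.
from omegaconf import OmegaConf
from main.utils_data import DataModuleFromConfig  # assumed import path

data_cfg = OmegaConf.create(
    {
        "train": {
            "target": "main.datasets.MyVideoDataset",            # hypothetical dataset class
            "params": {"data_root": "./data", "video_length": 16},
        }
    }
)
datamodule = DataModuleFromConfig(batch_size=2, train=data_cfg["train"], num_workers=4)
datamodule.setup()                      # instantiates the configured datasets
train_loader = datamodule.train_dataloader()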
NVComposer/requirements.txt
0 → 100755
View file @
30af93f2
pytorch_lightning
deepspeed
taming-transformers
scipy
einops
kornia
open_clip_torch
openai-clip
xformers
timm
av
gradio
\ No newline at end of file
NVComposer/utils/constants.py
0 → 100755
View file @
30af93f2
FLAG_RUN_DEBUG = False
PATH_DIR_DEBUG = "./debug/"
NVComposer/utils/load_weigths.py
0 → 100755
View file @
30af93f2
from utils.utils import instantiate_from_config
import torch
import copy
from omegaconf import OmegaConf
import logging

main_logger = logging.getLogger("main_logger")


def expand_conv_kernel(pretrained_dict):
    """expand 2d conv parameters from 4D -> 5D"""
    for k, v in pretrained_dict.items():
        if v.dim() == 4 and not k.startswith("first_stage_model"):
            v = v.unsqueeze(2)
            pretrained_dict[k] = v
    return pretrained_dict


def print_state_dict(state_dict):
    print("====== Dumping State Dict ======")
    for k, v in state_dict.items():
        print(k, v.shape)


def load_from_pretrainedSD_checkpoint(
    model,
    pretained_ckpt,
    expand_to_3d=True,
    adapt_keyname=False,
    echo_empty_params=False,
):
    sd_state_dict = torch.load(pretained_ckpt, map_location="cpu")
    if "state_dict" in list(sd_state_dict.keys()):
        sd_state_dict = sd_state_dict["state_dict"]
    model_state_dict = model.state_dict()
    # delete ema_weights just for <precise param counting>
    for k in list(sd_state_dict.keys()):
        if k.startswith("model_ema"):
            del sd_state_dict[k]
    main_logger.info(
        f"Num of model params of Source: {len(sd_state_dict.keys())} VS. Target: {len(model_state_dict.keys())}"
    )
    # print_state_dict(model_state_dict)
    # print_state_dict(sd_state_dict)

    if adapt_keyname:
        # adapting to standard 2d network: modify the key name because of the add of temporal-attention
        mapping_dict = {
            "middle_block.2": "middle_block.3",
            "output_blocks.5.2": "output_blocks.5.3",
            "output_blocks.8.2": "output_blocks.8.3",
        }
        cnt = 0
        for k in list(sd_state_dict.keys()):
            for src_word, dst_word in mapping_dict.items():
                if src_word in k:
                    new_key = k.replace(src_word, dst_word)
                    sd_state_dict[new_key] = sd_state_dict[k]
                    del sd_state_dict[k]
                    cnt += 1
        main_logger.info(f"[renamed {cnt} Source keys to match Target model]")

    pretrained_dict = {
        k: v for k, v in sd_state_dict.items() if k in model_state_dict
    }  # drop extra keys
    empty_paras = [
        k for k, v in model_state_dict.items() if k not in pretrained_dict
    ]  # log keys without pretrained weights
    assert len(empty_paras) + len(pretrained_dict.keys()) == len(
        model_state_dict.keys()
    )

    if expand_to_3d:
        # adapting to 2d inflated network
        pretrained_dict = expand_conv_kernel(pretrained_dict)

    # overwrite entries in the existing state dict
    model_state_dict.update(pretrained_dict)

    # load the new state dict
    try:
        model.load_state_dict(model_state_dict)
    except:
        skipped = []
        model_dict_ori = model.state_dict()
        for n, p in model_state_dict.items():
            if p.shape != model_dict_ori[n].shape:
                # skip by using original empty paras
                model_state_dict[n] = model_dict_ori[n]
                main_logger.info(
                    f"Skip para: {n}, size={pretrained_dict[n].shape} in pretrained, {model_state_dict[n].shape} in current model"
                )
                skipped.append(n)
        main_logger.info(
            f"[INFO] Skip {len(skipped)} parameters because of size mismatch!"
        )
        model.load_state_dict(model_state_dict)
        empty_paras += skipped

    # only count Unet part of depth estimation model
    unet_empty_paras = [
        name for name in empty_paras if name.startswith("model.diffusion_model")
    ]
    main_logger.info(
        f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)} [Unet: {len(unet_empty_paras)}]"
    )
    if echo_empty_params:
        print("Printing empty parameters:")
        for k in empty_paras:
            print(k)
    return model, empty_paras


# Below: written by Yingqing --------------------------------------------------------


def load_model_from_config(config, ckpt, verbose=False):
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        main_logger.info("missing keys:")
        main_logger.info(m)
    if len(u) > 0 and verbose:
        main_logger.info("unexpected keys:")
        main_logger.info(u)

    model.eval()
    return model


def init_and_load_ldm_model(config_path, ckpt_path, device=None):
    assert config_path.endswith(".yaml"), f"config_path = {config_path}"
    assert ckpt_path.endswith(".ckpt"), f"ckpt_path = {ckpt_path}"
    config = OmegaConf.load(config_path)
    model = load_model_from_config(config, ckpt_path)
    if device is not None:
        model = model.to(device)
    return model


def load_img_model_to_video_model(
    model,
    device=None,
    expand_to_3d=True,
    adapt_keyname=False,
    config_path="configs/latent-diffusion/txt2img-1p4B-eval.yaml",
    ckpt_path="models/ldm/text2img-large/model.ckpt",
):
    pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
    model, empty_paras = load_partial_weights(
        model,
        pretrained_ldm.state_dict(),
        expand_to_3d=expand_to_3d,
        adapt_keyname=adapt_keyname,
    )
    return model, empty_paras


def load_partial_weights(model, pretrained_dict, expand_to_3d=True, adapt_keyname=False):
    model2 = copy.deepcopy(model)
    model_dict = model.state_dict()
    model_dict_ori = copy.deepcopy(model_dict)

    main_logger.info(f"[Load pretrained LDM weights]")
    main_logger.info(
        f"Num of parameters of source model: {len(pretrained_dict.keys())} VS. target model: {len(model_dict.keys())}"
    )

    if adapt_keyname:
        # adapting to menghan's standard 2d network: modify the key name because of the add of temporal-attention
        mapping_dict = {
            "middle_block.2": "middle_block.3",
            "output_blocks.5.2": "output_blocks.5.3",
            "output_blocks.8.2": "output_blocks.8.3",
        }
        cnt = 0
        newpretrained_dict = copy.deepcopy(pretrained_dict)
        for k, v in newpretrained_dict.items():
            for src_word, dst_word in mapping_dict.items():
                if src_word in k:
                    new_key = k.replace(src_word, dst_word)
                    pretrained_dict[new_key] = v
                    pretrained_dict.pop(k)
                    cnt += 1
        main_logger.info(f"--renamed {cnt} source keys to match target model.")

    pretrained_dict = {
        k: v for k, v in pretrained_dict.items() if k in model_dict
    }  # drop extra keys
    empty_paras = [
        k for k, v in model_dict.items() if k not in pretrained_dict
    ]  # log keys without pretrained weights
    main_logger.info(
        f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)}"
    )
    # disable info
    # main_logger.info(f'Empty parameters: {empty_paras} ')
    assert len(empty_paras) + len(pretrained_dict.keys()) == len(model_dict.keys())

    if expand_to_3d:
        # adapting to yingqing's 2d inflation network
        pretrained_dict = expand_conv_kernel(pretrained_dict)

    # overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)

    # load the new state dict
    try:
        model2.load_state_dict(model_dict)
    except:
        # if parameter size mismatch, skip them
        skipped = []
        for n, p in model_dict.items():
            if p.shape != model_dict_ori[n].shape:
                # skip by using original empty paras
                model_dict[n] = model_dict_ori[n]
                main_logger.info(
                    f"Skip para: {n}, size={pretrained_dict[n].shape} in pretrained, {model_dict[n].shape} in current model"
                )
                skipped.append(n)
        main_logger.info(
            f"[INFO] Skip {len(skipped)} parameters because of size mismatch!"
        )
        model2.load_state_dict(model_dict)
        empty_paras += skipped
        main_logger.info(f"Empty parameters: {len(empty_paras)}")

    main_logger.info(f"Finished.")
    return model2, empty_paras


def load_autoencoder(model, config_path=None, ckpt_path=None, device=None):
    if config_path is None:
        config_path = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
    if ckpt_path is None:
        ckpt_path = "models/ldm/text2img-large/model.ckpt"
    pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
    autoencoder_dict = {}
    for n, p in pretrained_ldm.state_dict().items():
        if n.startswith("first_stage_model"):
            autoencoder_dict[n] = p
    model_dict = model.state_dict()
    model_dict.update(autoencoder_dict)
    main_logger.info(f"Load [{len(autoencoder_dict)}] autoencoder parameters!")
    model.load_state_dict(model_dict)
    return model
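A quick illustration (not part of the commit) of what expand_conv_kernel() does during 2D-to-3D inflation: 4D conv weights gain a singleton temporal axis so they fit inflated 3D convolutions, while autoencoder ("first_stage_model.*") weights are left untouched. The dummy keys and shapes below are purely illustrative.

# Hedged sketch (not in the commit); note the import keeps the repo's file name spelling.
import torch
from utils.load_weigths import expand_conv_kernel  # assumed import path

dummy = {
    "model.diffusion_model.conv.weight": torch.randn(320, 4, 3, 3),       # hypothetical UNet conv
    "first_stage_model.encoder.conv.weight": torch.randn(128, 3, 3, 3),   # hypothetical VAE conv
}
expanded = expand_conv_kernel(dummy)
print(expanded["model.diffusion_model.conv.weight"].shape)      # torch.Size([320, 4, 1, 3, 3])
print(expanded["first_stage_model.encoder.conv.weight"].shape)  # unchanged: [128, 3, 3, 3]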
NVComposer/utils/lr_scheduler.py
0 → 100755
View file @
30af93f2
import numpy as np
import torch
import torch.optim as optim


def build_LR_scheduler(optimizer, scheduler_name, lr_decay_ratio, max_epochs, start_epoch=0):
    # print("-LR scheduler:%s"%scheduler_name)
    if scheduler_name == "LambdaLR":
        decay_ratio = lr_decay_ratio
        decay_epochs = max_epochs

        def polynomial_decay(epoch):
            return (
                1 + (decay_ratio - 1) * ((epoch + start_epoch) / decay_epochs)
                if (epoch + start_epoch) < decay_epochs
                else decay_ratio
            )

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=polynomial_decay
        )
    elif scheduler_name == "CosineAnnealingLR":
        last_epoch = -1 if start_epoch == 0 else start_epoch
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max_epochs, last_epoch=last_epoch
        )
    elif scheduler_name == "ReduceLROnPlateau":
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, threshold=0.01, patience=5
        )
    else:
        raise NotImplementedError
    return lr_scheduler


class LambdaLRScheduler:
    # target: torch.optim.lr_scheduler.LambdaLR
    def __init__(self, start_step, final_decay_ratio, decay_steps):
        self.final_decay_ratio = final_decay_ratio
        self.decay_steps = decay_steps
        self.start_step = start_step

    def schedule(self, step):
        if step + self.start_step < self.decay_steps:
            return 1.0 + (self.final_decay_ratio - 1) * (
                (step + self.start_step) / self.decay_steps
            )
        else:
            return self.final_decay_ratio

    def __call__(self, step):
        return self.schedule(step)


class CosineAnnealingLRScheduler:
    # target: torch.optim.lr_scheduler.CosineAnnealingLR
    def __init__(self, start_step, decay_steps):
        self.decay_steps = decay_steps
        self.start_step = start_step

    def __call__(self, step):
        pass


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """

    def __init__(
        self,
        warm_up_steps,
        lr_min,
        lr_max,
        lr_start,
        max_decay_steps,
        verbosity_interval=0,
    ):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.0
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (
                self.lr_max - self.lr_start
            ) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (
                self.lr_max_decay_steps - self.lr_warm_up_steps
            )
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                1 + np.cos(t * np.pi)
            )
            self.last_lr = lr
            return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """

    def __init__(
        self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
    ):
        assert (
            len(warm_up_steps)
            == len(f_min)
            == len(f_max)
            == len(f_start)
            == len(cycle_lengths)
        )
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.0
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )
        if n < self.lr_warm_up_steps[cycle]:
            f = (
                self.f_max[cycle] - self.f_start[cycle]
            ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (
                self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
            )
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                1 + np.cos(t * np.pi)
            )
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )
        if n < self.lr_warm_up_steps[cycle]:
            f = (
                self.f_max[cycle] - self.f_start[cycle]
            ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
                self.cycle_lengths[cycle] - n
            ) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f
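These scheduler objects return a learning-rate multiplier rather than an absolute rate, so they are meant to be plugged into torch.optim.lr_scheduler.LambdaLR with the real magnitude carried by the optimizer's base lr. Below is a minimal wiring sketch, not part of the commit; the warm-up and cycle values are illustrative assumptions.

# Hedged wiring sketch (not in the commit): drive an optimizer with LambdaLinearScheduler.
import torch
from utils.lr_scheduler import LambdaLinearScheduler  # assumed import path

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)            # base lr carries the magnitude
multiplier = LambdaLinearScheduler(
    warm_up_steps=[1000],        # linear warm-up length (illustrative)
    f_min=[1.0],
    f_max=[1.0],
    f_start=[1e-6],
    cycle_lengths=[10_000_000],  # effectively a single long cycle (illustrative)
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=multiplier.schedule)
for step in range(5):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())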