Project: ModelZoo / Pyramid-Flow_pytorch

Commit 0e56f303, authored Nov 29, 2024 by mashun
Commit message: pyramid-flow
Pipeline #2007 canceled with stages. Changes: 85, Pipelines: 1.
Showing 5 changed files with 1591 additions and 0 deletions (+1591, -0):
video_vae/modeling_discriminator.py   +123  -0
video_vae/modeling_enc_dec.py         +423  -0
video_vae/modeling_loss.py            +192  -0
video_vae/modeling_lpips.py           +123  -0
video_vae/modeling_resnet.py          +730  -0
video_vae/modeling_discriminator.py  (new file, mode 100644)
import functools
import torch.nn as nn
from einops import rearrange
import torch


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=4):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        # norm_layer = nn.BatchNorm2d
        norm_layer = nn.InstanceNorm2d
        if type(norm_layer) == functools.partial:
            # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)


class NLayerDiscriminator3D(nn.Module):
    """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """
        Construct a 3D PatchGAN discriminator

        Parameters:
            input_nc (int)      -- the number of channels in input volumes
            ndf (int)           -- the number of filters in the last conv layer
            n_layers (int)      -- the number of conv layers in the discriminator
            use_actnorm (bool)  -- flag to use actnorm instead of batchnorm
        """
        super(NLayerDiscriminator3D, self).__init__()
        # if not use_actnorm:
        #     norm_layer = nn.BatchNorm3d
        # else:
        #     raise NotImplementedError("Not implemented.")
        norm_layer = nn.InstanceNorm3d
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func != nn.BatchNorm3d
        else:
            use_bias = norm_layer != nn.BatchNorm3d

        kw = 4
        padw = 1
        sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(1, 2, 2), padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
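A minimal usage sketch (not part of the commit): both discriminators return a patch-level logit map rather than a single scalar, so a random video batch can be used to sanity-check the output resolution. The input size below is an illustrative assumption.

    import torch
    # assumes: from video_vae.modeling_discriminator import NLayerDiscriminator3D, weights_init

    disc = NLayerDiscriminator3D(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
    video = torch.randn(1, 3, 16, 256, 256)   # (batch, channels, frames, height, width)
    logits = disc(video)                      # one logit per spatio-temporal patch
    print(logits.shape)                       # torch.Size([1, 1, 4, 30, 30]) for this input size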
video_vae/modeling_enc_dec.py  (new file, mode 100644)
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm
from .modeling_block import (
    UNetMidBlock2D,
    CausalUNetMidBlock2D,
    get_down_block,
    get_up_block,
    get_input_layer,
    get_output_layer,
)
from .modeling_resnet import (
    Downsample2D,
    Upsample2D,
    TemporalDownsample2x,
    TemporalUpsample2x,
)
from .modeling_causal_conv import CausalConv3d, CausalGroupNorm


@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: torch.FloatTensor


class CausalVaeEncoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
            options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        spatial_down_sample: Tuple[bool, ...] = (True,),
        temporal_down_sample: Tuple[bool, ...] = (False,),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: Tuple[int, ...] = (2,),
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        block_dropout: Tuple[int, ...] = (0.0,),
        mid_block_add_attention=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                add_spatial_downsample=spatial_down_sample[i],
                add_temporal_downsample=temporal_down_sample[i],
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
                dropout=block_dropout[i],
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = CausalUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            dropout=block_dropout[-1],
        )

        # out
        self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3, stride=1)

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down
            if is_torch_version(">=", "1.11.0"):
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(down_block), sample, is_init_image, temporal_chunk, use_reentrant=False
                    )
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, is_init_image, temporal_chunk, use_reentrant=False
                )
            else:
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(down_block), sample, is_init_image, temporal_chunk
                    )
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, is_init_image, temporal_chunk
                )

        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

            # middle
            sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        return sample


class CausalVaeDecoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        norm_type (`str`, *optional*, defaults to `"group"`):
            The normalization type to use. Can be either `"group"` or `"spatial"`.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        spatial_up_sample: Tuple[bool, ...] = (True,),
        temporal_up_sample: Tuple[bool, ...] = (False,),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: Tuple[int, ...] = (2,),
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        mid_block_add_attention=True,
        interpolate: bool = True,
        block_dropout: Tuple[int, ...] = (0.0,),
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # mid
        self.mid_block = CausalUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            dropout=block_dropout[-1],
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block[i],
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_spatial_upsample=spatial_up_sample[i],
                add_temporal_upsample=temporal_up_sample[i],
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
                resnet_time_scale_shift='default',
                interpolate=interpolate,
                dropout=block_dropout[i],
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3, stride=1)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        is_init_image=True,
        temporal_chunk=False,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""

        sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    is_init_image=is_init_image,
                    temporal_chunk=temporal_chunk,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        is_init_image=is_init_image,
                        temporal_chunk=temporal_chunk,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    is_init_image=is_init_image,
                    temporal_chunk=temporal_chunk,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        is_init_image=is_init_image,
                        temporal_chunk=temporal_chunk,
                    )
        else:
            # middle
            sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        return sample


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[2, 3, 4],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[2, 3, 4],
                )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        return self.mean
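A self-contained sketch of the math behind `DiagonalGaussianDistribution` (not part of the commit): the encoder's `conv_out` emits 2x the latent channels when `double_z=True`, which are split into a mean and a clamped log-variance, sampled with the reparameterization trick, and regularized with a KL term reduced over the (T, H, W) axes. Channel count and shapes below are illustrative assumptions.

    import torch

    latent_channels = 4
    moments = torch.randn(1, 2 * latent_channels, 2, 8, 8)    # encoder output: (B, 2C, T, H, W)

    mean, logvar = torch.chunk(moments, 2, dim=1)
    logvar = torch.clamp(logvar, -30.0, 20.0)
    std = torch.exp(0.5 * logvar)

    z = mean + std * torch.randn_like(mean)                    # reparameterized sample fed to the decoder

    # KL against a standard normal, reduced over (T, H, W) as in DiagonalGaussianDistribution.kl()
    kl = 0.5 * torch.sum(mean.pow(2) + logvar.exp() - 1.0 - logvar, dim=[2, 3, 4])
    print(z.shape, kl.shape)                                   # (1, 4, 2, 8, 8) and (1, 4)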
video_vae/modeling_loss.py  (new file, mode 100644)
import os
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

from .modeling_lpips import LPIPS
from .modeling_discriminator import NLayerDiscriminator, NLayerDiscriminator3D, weights_init


class AdaptiveLossWeight:
    def __init__(self, timestep_range=[0, 1], buckets=300, weight_range=[1e-7, 1e7]):
        self.bucket_ranges = torch.linspace(timestep_range[0], timestep_range[1], buckets - 1)
        self.bucket_losses = torch.ones(buckets)
        self.weight_range = weight_range

    def weight(self, timestep):
        indices = torch.searchsorted(self.bucket_ranges.to(timestep.device), timestep)
        return (1 / self.bucket_losses.to(timestep.device)[indices]).clamp(*self.weight_range)

    def update_buckets(self, timestep, loss, beta=0.99):
        indices = torch.searchsorted(self.bucket_ranges.to(timestep.device), timestep).cpu()
        self.bucket_losses[indices] = self.bucket_losses[indices] * beta + loss.detach().cpu() * (1 - beta)


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real))
        + torch.mean(torch.nn.functional.softplus(logits_fake))
    )
    return d_loss


def adopt_weight(weight, global_step, threshold=0, value=0.0):
    if global_step < threshold:
        weight = value
    return weight


class LPIPSWithDiscriminator(nn.Module):
    def __init__(
        self,
        disc_start,
        logvar_init=0.0,
        kl_weight=1.0,
        pixelloss_weight=1.0,
        perceptual_weight=1.0,
        lpips_ckpt='/home/jinyang06/models/vae/video_vae_baseline/vgg_lpips.pth',
        # --- Discriminator Loss ---
        disc_num_layers=4,
        disc_in_channels=3,
        disc_factor=1.0,
        disc_weight=0.5,
        disc_loss="hinge",
        add_discriminator=True,
        using_3d_discriminator=False,
    ):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS(lpips_ckpt_path=lpips_ckpt).eval()
        self.perceptual_weight = perceptual_weight
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        if add_discriminator:
            disc_cls = NLayerDiscriminator3D if using_3d_discriminator else NLayerDiscriminator
            self.discriminator = disc_cls(
                input_nc=disc_in_channels, n_layers=disc_num_layers,
            ).apply(weights_init)
        else:
            self.discriminator = None

        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.using_3d_discriminator = using_3d_discriminator

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        inputs,
        reconstructions,
        posteriors,
        optimizer_idx,
        global_step,
        split="train",
        last_layer=None,
    ):
        t = reconstructions.shape[2]
        inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
        reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w").contiguous()

        if optimizer_idx == 0:
            # rec_loss = torch.mean(torch.abs(inputs - reconstructions), dim=(1,2,3), keepdim=True)
            rec_loss = torch.mean(F.mse_loss(inputs, reconstructions, reduction='none'), dim=(1, 2, 3), keepdim=True)

            if self.perceptual_weight > 0:
                p_loss = self.perceptual_loss(inputs, reconstructions)
                nll_loss = self.pixel_weight * rec_loss + self.perceptual_weight * p_loss

            nll_loss = nll_loss / torch.exp(self.logvar) + self.logvar
            weighted_nll_loss = nll_loss
            weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
            nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
            kl_loss = posteriors.kl()
            kl_loss = torch.mean(kl_loss)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)

            if disc_factor > 0.0:
                if self.using_3d_discriminator:
                    reconstructions = rearrange(reconstructions, '(b t) c h w -> b c t h w', t=t)

                logits_fake = self.discriminator(reconstructions.contiguous())
                g_loss = -torch.mean(logits_fake)
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)
                g_loss = torch.tensor(0.0)

            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {
                "{}/total_loss".format(split): loss.clone().detach().mean(),
                "{}/logvar".format(split): self.logvar.detach(),
                "{}/kl_loss".format(split): kl_loss.detach().mean(),
                "{}/nll_loss".format(split): nll_loss.detach().mean(),
                "{}/rec_loss".format(split): rec_loss.detach().mean(),
                "{}/perception_loss".format(split): p_loss.detach().mean(),
                "{}/d_weight".format(split): d_weight.detach(),
                "{}/disc_factor".format(split): torch.tensor(disc_factor),
                "{}/g_loss".format(split): g_loss.detach().mean(),
            }
            return loss, log

        if optimizer_idx == 1:
            if self.using_3d_discriminator:
                inputs = rearrange(inputs, '(b t) c h w -> b c t h w', t=t)
                reconstructions = rearrange(reconstructions, '(b t) c h w -> b c t h w', t=t)

            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {
                "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                "{}/logits_real".format(split): logits_real.detach().mean(),
                "{}/logits_fake".format(split): logits_fake.detach().mean(),
            }
            return d_loss, log
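`calculate_adaptive_weight` balances the adversarial term against the reconstruction/NLL term by comparing their gradient norms at the decoder's last layer, mirroring the VQGAN-style loss. A stand-alone sketch of that formula (not part of the commit; all tensors here are illustrative stand-ins):

    import torch

    last_layer = torch.randn(8, 4, requires_grad=True)         # stand-in for the decoder's final conv weight
    nll_loss = (last_layer * 2.0).sum()                        # stand-in reconstruction/NLL term
    g_loss = (last_layer * -0.5).sum()                         # stand-in generator (adversarial) term

    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() * 0.5  # 0.5 = the disc_weight default
    print(d_weight)                                            # ~2.0: rec gradients are 4x larger, scaled by 0.5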
video_vae/modeling_lpips.py  (new file, mode 100644)
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
import
torch
import
torch.nn
as
nn
from
torchvision
import
models
from
collections
import
namedtuple
class
LPIPS
(
nn
.
Module
):
# Learned perceptual metric
def
__init__
(
self
,
use_dropout
=
True
,
lpips_ckpt_path
=
None
):
super
().
__init__
()
self
.
lpips_ckpt_path
=
lpips_ckpt_path
# replace with your lpips path
self
.
scaling_layer
=
ScalingLayer
()
self
.
chns
=
[
64
,
128
,
256
,
512
,
512
]
# vg16 features
self
.
net
=
vgg16
(
pretrained
=
True
,
requires_grad
=
False
)
self
.
lin0
=
NetLinLayer
(
self
.
chns
[
0
],
use_dropout
=
use_dropout
)
self
.
lin1
=
NetLinLayer
(
self
.
chns
[
1
],
use_dropout
=
use_dropout
)
self
.
lin2
=
NetLinLayer
(
self
.
chns
[
2
],
use_dropout
=
use_dropout
)
self
.
lin3
=
NetLinLayer
(
self
.
chns
[
3
],
use_dropout
=
use_dropout
)
self
.
lin4
=
NetLinLayer
(
self
.
chns
[
4
],
use_dropout
=
use_dropout
)
self
.
load_from_pretrained
()
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
def
load_from_pretrained
(
self
):
ckpt
=
self
.
lpips_ckpt_path
assert
ckpt
is
not
None
,
"Please replace with your lpips path"
self
.
load_state_dict
(
torch
.
load
(
ckpt
,
map_location
=
torch
.
device
(
"cpu"
)),
strict
=
False
)
print
(
"loaded pretrained LPIPS loss from {}"
.
format
(
ckpt
))
def
forward
(
self
,
input
,
target
):
in0_input
,
in1_input
=
(
self
.
scaling_layer
(
input
),
self
.
scaling_layer
(
target
))
outs0
,
outs1
=
self
.
net
(
in0_input
),
self
.
net
(
in1_input
)
feats0
,
feats1
,
diffs
=
{},
{},
{}
lins
=
[
self
.
lin0
,
self
.
lin1
,
self
.
lin2
,
self
.
lin3
,
self
.
lin4
]
for
kk
in
range
(
len
(
self
.
chns
)):
feats0
[
kk
],
feats1
[
kk
]
=
normalize_tensor
(
outs0
[
kk
]),
normalize_tensor
(
outs1
[
kk
])
diffs
[
kk
]
=
(
feats0
[
kk
]
-
feats1
[
kk
])
**
2
res
=
[
spatial_average
(
lins
[
kk
].
model
(
diffs
[
kk
]),
keepdim
=
True
)
for
kk
in
range
(
len
(
self
.
chns
))]
val
=
res
[
0
]
for
l
in
range
(
1
,
len
(
self
.
chns
)):
val
+=
res
[
l
]
return
val
class
ScalingLayer
(
nn
.
Module
):
def
__init__
(
self
):
super
(
ScalingLayer
,
self
).
__init__
()
self
.
register_buffer
(
'shift'
,
torch
.
Tensor
([
-
.
030
,
-
.
088
,
-
.
188
])[
None
,
:,
None
,
None
])
self
.
register_buffer
(
'scale'
,
torch
.
Tensor
([.
458
,
.
448
,
.
450
])[
None
,
:,
None
,
None
])
def
forward
(
self
,
inp
):
return
(
inp
-
self
.
shift
)
/
self
.
scale
class
NetLinLayer
(
nn
.
Module
):
""" A single linear layer which does a 1x1 conv """
def
__init__
(
self
,
chn_in
,
chn_out
=
1
,
use_dropout
=
False
):
super
(
NetLinLayer
,
self
).
__init__
()
layers
=
[
nn
.
Dropout
(),
]
if
(
use_dropout
)
else
[]
layers
+=
[
nn
.
Conv2d
(
chn_in
,
chn_out
,
1
,
stride
=
1
,
padding
=
0
,
bias
=
False
),
]
self
.
model
=
nn
.
Sequential
(
*
layers
)
class
vgg16
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
requires_grad
=
False
,
pretrained
=
True
):
super
(
vgg16
,
self
).
__init__
()
vgg_pretrained_features
=
models
.
vgg16
(
pretrained
=
pretrained
).
features
self
.
slice1
=
torch
.
nn
.
Sequential
()
self
.
slice2
=
torch
.
nn
.
Sequential
()
self
.
slice3
=
torch
.
nn
.
Sequential
()
self
.
slice4
=
torch
.
nn
.
Sequential
()
self
.
slice5
=
torch
.
nn
.
Sequential
()
self
.
N_slices
=
5
for
x
in
range
(
4
):
self
.
slice1
.
add_module
(
str
(
x
),
vgg_pretrained_features
[
x
])
for
x
in
range
(
4
,
9
):
self
.
slice2
.
add_module
(
str
(
x
),
vgg_pretrained_features
[
x
])
for
x
in
range
(
9
,
16
):
self
.
slice3
.
add_module
(
str
(
x
),
vgg_pretrained_features
[
x
])
for
x
in
range
(
16
,
23
):
self
.
slice4
.
add_module
(
str
(
x
),
vgg_pretrained_features
[
x
])
for
x
in
range
(
23
,
30
):
self
.
slice5
.
add_module
(
str
(
x
),
vgg_pretrained_features
[
x
])
if
not
requires_grad
:
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
def
forward
(
self
,
X
):
h
=
self
.
slice1
(
X
)
h_relu1_2
=
h
h
=
self
.
slice2
(
h
)
h_relu2_2
=
h
h
=
self
.
slice3
(
h
)
h_relu3_3
=
h
h
=
self
.
slice4
(
h
)
h_relu4_3
=
h
h
=
self
.
slice5
(
h
)
h_relu5_3
=
h
vgg_outputs
=
namedtuple
(
"VggOutputs"
,
[
'relu1_2'
,
'relu2_2'
,
'relu3_3'
,
'relu4_3'
,
'relu5_3'
])
out
=
vgg_outputs
(
h_relu1_2
,
h_relu2_2
,
h_relu3_3
,
h_relu4_3
,
h_relu5_3
)
return
out
def
normalize_tensor
(
x
,
eps
=
1e-10
):
norm_factor
=
torch
.
sqrt
(
torch
.
sum
(
x
**
2
,
dim
=
1
,
keepdim
=
True
))
return
x
/
(
norm_factor
+
eps
)
def
spatial_average
(
x
,
keepdim
=
True
):
return
x
.
mean
([
2
,
3
],
keepdim
=
keepdim
)
if
__name__
==
"__main__"
:
model
=
LPIPS
().
eval
()
_
=
torch
.
manual_seed
(
123
)
img1
=
(
torch
.
rand
(
10
,
3
,
100
,
100
)
*
2
)
-
1
img2
=
(
torch
.
rand
(
10
,
3
,
100
,
100
)
*
2
)
-
1
print
(
model
(
img1
,
img2
).
shape
)
# embed()
\ No newline at end of file
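In `modeling_loss.py` this LPIPS metric is applied per frame: the video batch is flattened from (B, C, T, H, W) to (B*T, C, H, W) before scoring. A small sketch of that reshaping (not part of the commit; shapes and the checkpoint path are illustrative):

    import torch
    from einops import rearrange

    video = torch.rand(2, 3, 5, 64, 64) * 2 - 1                # values in [-1, 1], as LPIPS expects
    frames = rearrange(video, 'b c t h w -> (b t) c h w').contiguous()
    print(frames.shape)                                        # torch.Size([10, 3, 64, 64])
    # lpips = LPIPS(lpips_ckpt_path='<path to vgg_lpips.pth>').eval()
    # per_frame = lpips(frames, frames)                        # (10, 1, 1, 1) perceptual distances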
video_vae/modeling_resnet.py  (new file, mode 100644)
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.normalization import AdaGroupNorm
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from .modeling_causal_conv import CausalConv3d, CausalGroupNorm


class CausalResnetBlock3D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm

        linear_cls = nn.Linear

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = CausalGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = CausalGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = CausalConv3d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = CausalConv3d(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                bias=conv_shortcut_bias,
            )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor = None,
        is_init_image=True,
        temporal_chunk=False,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm

        linear_cls = nn.Linear
        conv_cls = nn.Conv3d

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = conv_cls(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class CausalDownsample2x(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = True,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        stride = (1, 2, 2)
        self.name = name

        if use_conv:
            conv = CausalConv3d(
                self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
            )
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool3d(kernel_size=stride, stride=stride)

        self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
        return hidden_states


class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = True,
        out_channels: Optional[int] = None,
        padding: int = 0,
        name: str = "conv",
        kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = (1, 2, 2)
        self.name = name
        conv_cls = nn.Conv3d

        if use_conv:
            conv = conv_cls(
                self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
            )
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1, 1, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels

        hidden_states = self.conv(hidden_states)

        return hidden_states


class TemporalDownsample2x(nn.Module):
    """A Temporal downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 0,
        kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = (2, 1, 1)
        conv_cls = nn.Conv3d

        if use_conv:
            conv = conv_cls(
                self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
            )
        else:
            raise NotImplementedError("Not implemented for temporal downsample without")

        self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            if hidden_states.shape[2] == 1:
                # image
                pad = (1, 1, 1, 1, 1, 1)
            else:
                # video
                pad = (1, 1, 1, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        hidden_states = self.conv(hidden_states)
        return hidden_states


class CausalTemporalDownsample2x(nn.Module):
    """A Temporal downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        stride = (2, 1, 1)
        conv_cls = nn.Conv3d

        if use_conv:
            conv = CausalConv3d(
                self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
            )
        else:
            raise NotImplementedError("Not implemented for temporal downsample without")

        self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
        return hidden_states


class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = None,
        padding=1,
        bias=True,
        interpolate=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.name = name
        self.interpolate = interpolate
        conv_cls = nn.Conv3d

        conv = None
        if interpolate:
            raise NotImplementedError("Not implemented for spatial upsample with interpolate")
        else:
            if kernel_size is None:
                kernel_size = 3
            conv = conv_cls(self.channels, self.out_channels * 4, kernel_size=kernel_size, padding=padding, bias=bias)

        self.conv = conv
        self.conv.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)
        hidden_states = rearrange(hidden_states, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)
        return hidden_states


class CausalUpsample2x(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = 3,
        bias=True,
        interpolate=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.name = name
        self.interpolate = interpolate

        conv = None
        if interpolate:
            raise NotImplementedError("Not implemented for spatial upsample with interpolate")
        else:
            conv = CausalConv3d(self.channels, self.out_channels * 4, kernel_size=kernel_size, stride=1, bias=bias)

        self.conv = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        is_init_image=True,
        temporal_chunk=False,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
        hidden_states = rearrange(hidden_states, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)
        return hidden_states


class TemporalUpsample2x(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = True,
        out_channels: Optional[int] = None,
        kernel_size: Optional[int] = None,
        padding=1,
        bias=True,
        interpolate=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.interpolate = interpolate
        conv_cls = nn.Conv3d

        conv = None
        if interpolate:
            raise NotImplementedError("Not implemented for spatial upsample with interpolate")
        else:
            # depth to space operator
            if kernel_size is None:
                kernel_size = 3
            conv = conv_cls(self.channels, self.out_channels * 2, kernel_size=kernel_size, padding=padding, bias=bias)

        self.conv = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        is_image: bool = False,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        t = hidden_states.shape[2]
        hidden_states = self.conv(hidden_states)
        hidden_states = rearrange(hidden_states, 'b (c p) t h w -> b c (p t) h w', p=2)

        if t == 1 and is_image:
            hidden_states = hidden_states[:, :, 1:]

        return hidden_states


class CausalTemporalUpsample2x(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = True,
        out_channels: Optional[int] = None,
        kernel_size: Optional[int] = 3,
        bias=True,
        interpolate=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.interpolate = interpolate

        conv = None
        if interpolate:
            raise NotImplementedError("Not implemented for spatial upsample with interpolate")
        else:
            # depth to space operator
            conv = CausalConv3d(self.channels, self.out_channels * 2, kernel_size=kernel_size, stride=1, bias=bias)

        self.conv = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        is_init_image=True,
        temporal_chunk=False,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        t = hidden_states.shape[2]
        hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
        hidden_states = rearrange(hidden_states, 'b (c p) t h w -> b c (t p) h w', p=2)

        if is_init_image:
            hidden_states = hidden_states[:, :, 1:]

        return hidden_states
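The spatial upsamplers in this file avoid interpolation: they predict 4x as many channels and then rearrange them into a 2x larger spatial grid (a depth-to-space / pixel-shuffle step), and the temporal upsamplers do the same with 2x channels along the frame axis. A shape-only sketch of the spatial rearrange (not part of the commit; tensor sizes are illustrative):

    import torch
    from einops import rearrange

    c = 16
    feat = torch.randn(1, 4 * c, 3, 8, 8)                     # conv output with 4x channels: (B, 4C, T, H, W)
    up = rearrange(feat, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)
    print(up.shape)                                            # torch.Size([1, 16, 3, 16, 16])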