chenpangpang / Ruyi-Mini-7B · Commits

Commit 08a21d59, authored Dec 27, 2024 by chenpangpang
Parent: 1a6b26f1
Commit message: feat: initial commit
Pipeline #2165 failed in 0 seconds.

Showing 15 changed files with 2492 additions and 0 deletions (+2492, -0).
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/attention_processors.py   +139  -0
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/common.py                 +302  -0
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/discriminator.py          +214  -0
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/down_blocks.py            +533  -0
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/downsamplers.py           +148  -0
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/gc_block.py               +79   -0
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/mid_blocks.py             +196  -0
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/up_blocks.py              +395  -0
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/upsamplers.py             +183  -0
Ruyi-Models/ruyi/vae/ldm/util.py                                      +201  -0
Ruyi-Models/ruyi/vae/setup.py                                         +13   -0
assets/二维码.jpeg                                                     +0    -0
hf_down.py                                                            +4    -0
start.sh                                                              +4    -0
启动器.ipynb                                                           +81   -0
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/attention_processors.py (new file, mode 100644)

from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F

if TYPE_CHECKING:
    from .attention import Attention


class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states,
        attention_mask,
        temb=None,
    ) -> torch.Tensor:
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb=temb)

        # B, L, C
        assert hidden_states.ndim == 3, f"Hidden states must be 3-dimensional, got {hidden_states.ndim}"

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2))
            hidden_states = hidden_states.transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        hidden_states = attn.to_out(hidden_states)
        hidden_states = attn.dropout(hidden_states)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states,
        attention_mask,
        temb=None,
    ) -> torch.FloatTensor:
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb=temb)

        # B, L, C
        assert hidden_states.ndim == 3, f"Hidden states must be 3-dimensional, got {hidden_states.ndim}"

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.nheads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2))
            hidden_states = hidden_states.transpose(1, 2)

        query: torch.Tensor = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key: torch.Tensor = attn.to_k(encoder_hidden_states)
        value: torch.Tensor = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.nheads

        query = query.view(batch_size, -1, attn.nheads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.nheads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.nheads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.nheads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = attn.to_out(hidden_states)
        hidden_states = attn.dropout(hidden_states)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
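For reference, the head handling that AttnProcessor2_0 wraps around F.scaled_dot_product_attention can be sketched standalone. This is a minimal illustration with made-up tensor sizes, assuming only PyTorch 2.0+; it is not tied to the Attention class above.

import torch
import torch.nn.functional as F

batch, seq_len, nheads, head_dim = 2, 16, 4, 32
hidden = torch.randn(batch, seq_len, nheads * head_dim)

# Split heads: (B, L, C) -> (B, nheads, L, head_dim), mirroring AttnProcessor2_0.
q = hidden.view(batch, -1, nheads, head_dim).transpose(1, 2)
k = hidden.view(batch, -1, nheads, head_dim).transpose(1, 2)
v = hidden.view(batch, -1, nheads, head_dim).transpose(1, 2)

# Fused attention (PyTorch 2.0+); output shape is (B, nheads, L, head_dim).
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Merge heads back to (B, L, C).
out = out.transpose(1, 2).reshape(batch, -1, nheads * head_dim)
print(out.shape)  # torch.Size([2, 16, 128])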
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/common.py (new file, mode 100755)

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .activations import get_activation


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


def divisible_by(num, den):
    return (num % den) == 0


def is_odd(n):
    return not divisible_by(n, 2)


class CausalConv3d(nn.Conv3d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,  # : int | tuple[int, int, int],
        stride=1,  # : int | tuple[int, int, int] = 1,
        padding=1,  # : int | tuple[int, int, int], # TODO: change it to 0.
        dilation=1,  # : int | tuple[int, int, int] = 1,
        **kwargs,
    ):
        kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
        assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."

        stride = stride if isinstance(stride, tuple) else (stride,) * 3
        assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."

        dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
        assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."

        t_ks, h_ks, w_ks = kernel_size
        self.t_stride, h_stride, w_stride = stride
        t_dilation, h_dilation, w_dilation = dilation

        t_pad = (t_ks - 1) * t_dilation
        # TODO: align with SD
        if padding is None:
            h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
            w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
        elif isinstance(padding, int):
            h_pad = w_pad = padding
        else:
            raise NotImplementedError

        self.temporal_padding = t_pad
        # Symmetric temporal padding for the non-causal fallback path.
        self.temporal_padding_origin = math.ceil(((t_ks - 1) * t_dilation + (1 - self.t_stride)) / 2)
        self.padding_flag = 0
        self.prev_features = None

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=(0, h_pad, w_pad),
            **kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        dtype = x.dtype
        x = x.float()
        if self.padding_flag == 0:
            x = F.pad(
                x,
                pad=(0, 0, 0, 0, self.temporal_padding, 0),
                mode="replicate",  # TODO: check if this is necessary
            )
            x = x.to(dtype=dtype)
            return super().forward(x)
        elif self.padding_flag == 5:
            x = F.pad(
                x,
                pad=(0, 0, 0, 0, self.temporal_padding, 0),
                mode="replicate",  # TODO: check if this is necessary
            )
            x = x.to(dtype=dtype)
            self.prev_features = x[:, :, -self.temporal_padding:]
            return super().forward(x)
        elif self.padding_flag == 6:
            if self.t_stride == 2:
                x = torch.concat(
                    [self.prev_features[:, :, -(self.temporal_padding - 1):], x], dim=2
                )
            else:
                x = torch.concat([self.prev_features, x], dim=2)
            self.prev_features = x[:, :, -self.temporal_padding:]
            x = x.to(dtype=dtype)
            return super().forward(x)
        else:
            x = F.pad(
                x,
                pad=(0, 0, 0, 0, self.temporal_padding_origin, self.temporal_padding_origin),
            )
            x = x.to(dtype=dtype)
            return super().forward(x)


class ResidualBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        non_linearity: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()

        self.output_scale_factor = output_scale_factor

        self.norm1 = nn.GroupNorm(
            num_groups=norm_num_groups,
            num_channels=in_channels,
            eps=norm_eps,
            affine=True,
        )

        self.nonlinearity = get_activation(non_linearity)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.norm2 = nn.GroupNorm(
            num_groups=norm_num_groups,
            num_channels=out_channels,
            eps=norm_eps,
            affine=True,
        )

        self.dropout = nn.Dropout(dropout)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

        self.set_3dgroupnorm = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.shortcut(x)

        if self.set_3dgroupnorm:
            batch_size = x.shape[0]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = self.norm1(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
        else:
            x = self.norm1(x)
        x = self.nonlinearity(x)

        x = self.conv1(x)

        if self.set_3dgroupnorm:
            batch_size = x.shape[0]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = self.norm2(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
        else:
            x = self.norm2(x)
        x = self.nonlinearity(x)

        x = self.dropout(x)
        x = self.conv2(x)

        return (x + shortcut) / self.output_scale_factor


class ResidualBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        non_linearity: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()

        self.output_scale_factor = output_scale_factor

        self.norm1 = nn.GroupNorm(
            num_groups=norm_num_groups,
            num_channels=in_channels,
            eps=norm_eps,
            affine=True,
        )

        self.nonlinearity = get_activation(non_linearity)

        self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3)

        self.norm2 = nn.GroupNorm(
            num_groups=norm_num_groups,
            num_channels=out_channels,
            eps=norm_eps,
            affine=True,
        )

        self.dropout = nn.Dropout(dropout)

        self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3)

        if in_channels != out_channels:
            self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

        self.set_3dgroupnorm = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.shortcut(x)

        if self.set_3dgroupnorm:
            batch_size = x.shape[0]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = self.norm1(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
        else:
            x = self.norm1(x)
        x = self.nonlinearity(x)

        x = self.conv1(x)

        if self.set_3dgroupnorm:
            batch_size = x.shape[0]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = self.norm2(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
        else:
            x = self.norm2(x)
        x = self.nonlinearity(x)

        x = self.dropout(x)
        x = self.conv2(x)

        return (x + shortcut) / self.output_scale_factor


class SpatialNorm2D(nn.Module):
    """
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.

    Args:
        f_channels (`int`):
            The number of channels for input to group normalization layer, and output of the spatial norm layer.
        zq_channels (`int`):
            The number of channels for the quantized vector as described in the paper.
    """

    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
    ):
        super().__init__()
        self.norm = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
        self.set_3dgroupnorm = False

    def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
        f_size = f.shape[-2:]
        zq = F.interpolate(zq, size=f_size, mode="nearest")
        if self.set_3dgroupnorm:
            batch_size = f.shape[0]
            f = rearrange(f, "b c t h w -> (b t) c h w")
            norm_f = self.norm(f)
            norm_f = rearrange(norm_f, "(b t) c h w -> b c t h w", b=batch_size)
        else:
            norm_f = self.norm(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f


class SpatialNorm3D(SpatialNorm2D):
    def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
        batch_size = f.shape[0]
        f = rearrange(f, "b c t h w -> (b t) c h w")
        zq = rearrange(zq, "b c t h w -> (b t) c h w")
        x = super().forward(f, zq)
        x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
        return x
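A short usage sketch for CausalConv3d follows; the import path is an assumption based on the file layout above (the package may not actually be importable under this name). In the default padding_flag == 0 mode, replicate padding is applied only on the past side of the time axis, so the temporal length is preserved at stride 1:

import torch
from ruyi.vae.ldm.modules.vaemodules.common import CausalConv3d  # assumed import path

conv = CausalConv3d(in_channels=8, out_channels=8, kernel_size=3)
x = torch.randn(1, 8, 5, 32, 32)  # (B, C, T, H, W)
y = conv(x)
print(y.shape)  # torch.Size([1, 8, 5, 32, 32]); past-side padding keeps T unchanged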
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/discriminator.py (new file, mode 100644)

import math

import torch
import torch.nn as nn

from .downsamplers import BlurPooling2D, BlurPooling3D


class DiscriminatorBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.output_scale_factor = output_scale_factor

        self.norm1 = nn.BatchNorm2d(in_channels)
        self.nonlinearity = nn.LeakyReLU(0.2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        if add_downsample:
            self.downsampler = BlurPooling2D(out_channels, out_channels)
        else:
            self.downsampler = nn.Identity()

        self.norm2 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if add_downsample:
            self.shortcut = nn.Sequential(
                BlurPooling2D(in_channels, in_channels),
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
            )
        else:
            self.shortcut = nn.Identity()

        self.spatial_downsample_factor = 2
        self.temporal_downsample_factor = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.shortcut(x)

        x = self.norm1(x)
        x = self.nonlinearity(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.nonlinearity(x)
        x = self.dropout(x)
        x = self.downsampler(x)
        x = self.conv2(x)

        return (x + shortcut) / self.output_scale_factor


class Discriminator2D(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        block_out_channels=(64,),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])
        output_channels = block_out_channels[0]
        for i, out_channels in enumerate(block_out_channels):
            input_channels = output_channels
            output_channels = out_channels
            is_final_block = i == len(block_out_channels) - 1
            self.blocks.append(
                DiscriminatorBlock2D(
                    in_channels=input_channels,
                    out_channels=output_channels,
                    output_scale_factor=math.sqrt(2),
                    add_downsample=not is_final_block,
                )
            )

        self.conv_norm_out = nn.BatchNorm2d(block_out_channels[-1])
        self.conv_act = nn.LeakyReLU(0.2)
        self.conv_out = nn.Conv2d(block_out_channels[-1], 1, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W)
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        x = self.conv_out(x)
        return x


class DiscriminatorBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.output_scale_factor = output_scale_factor

        self.norm1 = nn.GroupNorm(32, in_channels)
        self.nonlinearity = nn.LeakyReLU(0.2)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)

        if add_downsample:
            self.downsampler = BlurPooling3D(out_channels, out_channels)
        else:
            self.downsampler = nn.Identity()

        self.norm2 = nn.GroupNorm(32, out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)

        if add_downsample:
            self.shortcut = nn.Sequential(
                BlurPooling3D(in_channels, in_channels),
                nn.Conv3d(in_channels, out_channels, kernel_size=1),
            )
        else:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1),
            )

        self.spatial_downsample_factor = 2
        self.temporal_downsample_factor = 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.shortcut(x)

        x = self.norm1(x)
        x = self.nonlinearity(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.nonlinearity(x)
        x = self.dropout(x)
        x = self.downsampler(x)
        x = self.conv2(x)

        return (x + shortcut) / self.output_scale_factor


class Discriminator3D(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        block_out_channels=(64,),
    ):
        super().__init__()

        self.conv_in = nn.Conv3d(in_channels, block_out_channels[0], kernel_size=3, padding=1, stride=2)

        self.blocks = nn.ModuleList([])
        output_channels = block_out_channels[0]
        for i, out_channels in enumerate(block_out_channels):
            input_channels = output_channels
            output_channels = out_channels
            is_final_block = i == len(block_out_channels) - 1
            self.blocks.append(
                DiscriminatorBlock3D(
                    in_channels=input_channels,
                    out_channels=output_channels,
                    output_scale_factor=math.sqrt(2),
                    add_downsample=not is_final_block,
                )
            )

        self.conv_norm_out = nn.GroupNorm(32, block_out_channels[-1])
        self.conv_act = nn.LeakyReLU(0.2)
        self.conv_out = nn.Conv3d(block_out_channels[-1], 1, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        x = self.conv_out(x)
        return x
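A usage sketch for Discriminator3D, with the import path assumed from the layout above. Note that conv_norm_out and conv_act are constructed but never called in forward as written; the head maps directly to one logit channel per spatio-temporal patch:

import torch
from ruyi.vae.ldm.modules.vaemodules.discriminator import Discriminator3D  # assumed import path

disc = Discriminator3D(in_channels=3, block_out_channels=(64, 128))
video = torch.randn(1, 3, 8, 64, 64)  # (B, C, T, H, W)
logits = disc(video)
print(logits.shape)  # torch.Size([1, 1, 2, 16, 16]): strided conv_in plus one blur-pooled block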
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/down_blocks.py (new file, mode 100755)

import torch
import torch.nn as nn

from .attention import SpatialAttention, TemporalAttention
from .common import ResidualBlock3D
from .downsamplers import (
    SpatialDownsampler3D,
    SpatialTemporalDownsampler3D,
    TemporalDownsampler3D
)
from .gc_block import GlobalContextBlock


def get_down_block(
    down_block_type: str,
    in_channels: int,
    out_channels: int,
    num_layers: int,
    act_fn: str,
    norm_num_groups: int = 32,
    norm_eps: float = 1e-6,
    dropout: float = 0.0,
    num_attention_heads: int = 1,
    output_scale_factor: float = 1.0,
    add_gc_block: bool = False,
    add_downsample: bool = True,
) -> nn.Module:
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
        )
    elif down_block_type == "SpatialDownBlock3D":
        return SpatialDownBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_downsample=add_downsample,
        )
    elif down_block_type == "SpatialAttnDownBlock3D":
        return SpatialAttnDownBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            attention_head_dim=out_channels // num_attention_heads,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_downsample=add_downsample,
        )
    elif down_block_type == "TemporalDownBlock3D":
        return TemporalDownBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_downsample=add_downsample,
        )
    elif down_block_type == "TemporalAttnDownBlock3D":
        return TemporalAttnDownBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            attention_head_dim=out_channels // num_attention_heads,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_downsample=add_downsample,
        )
    elif down_block_type == "SpatialTemporalDownBlock3D":
        return SpatialTemporalDownBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_downsample=add_downsample,
        )
    else:
        raise ValueError(f"Unknown down block type: {down_block_type}")


class DownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        self.spatial_downsample_factor = 1
        self.temporal_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        return x


class SpatialDownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_downsample:
            self.downsampler = SpatialDownsampler3D(out_channels, out_channels)
            self.spatial_downsample_factor = 2
        else:
            self.downsampler = None
            self.spatial_downsample_factor = 1

        self.temporal_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.downsampler is not None:
            x = self.downsampler(x)

        return x


class TemporalDownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_downsample:
            self.downsampler = TemporalDownsampler3D(out_channels, out_channels)
            self.temporal_downsample_factor = 2
        else:
            self.downsampler = None
            self.temporal_downsample_factor = 1

        self.spatial_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.downsampler is not None:
            x = self.downsampler(x)

        return x


class SpatialTemporalDownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_downsample:
            self.downsampler = SpatialTemporalDownsampler3D(out_channels, out_channels)
            self.spatial_downsample_factor = 2
            self.temporal_downsample_factor = 2
        else:
            self.downsampler = None
            self.spatial_downsample_factor = 1
            self.temporal_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.downsampler is not None:
            x = self.downsampler(x)

        return x


class SpatialAttnDownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        self.attentions = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )
            self.attentions.append(
                SpatialAttention(
                    out_channels,
                    nheads=out_channels // attention_head_dim,
                    head_dim=attention_head_dim,
                    bias=True,
                    upcast_softmax=True,
                    norm_num_groups=norm_num_groups,
                    eps=norm_eps,
                    rescale_output_factor=output_scale_factor,
                    residual_connection=True,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_downsample:
            self.downsampler = SpatialDownsampler3D(out_channels, out_channels)
            self.spatial_downsample_factor = 2
        else:
            self.downsampler = None
            self.spatial_downsample_factor = 1

        self.temporal_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv, attn in zip(self.convs, self.attentions):
            x = conv(x)
            x = attn(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.downsampler is not None:
            x = self.downsampler(x)

        return x


class TemporalAttnDownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        self.attentions = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )
            self.attentions.append(
                TemporalAttention(
                    out_channels,
                    nheads=out_channels // attention_head_dim,
                    head_dim=attention_head_dim,
                    bias=True,
                    upcast_softmax=True,
                    norm_num_groups=norm_num_groups,
                    eps=norm_eps,
                    rescale_output_factor=output_scale_factor,
                    residual_connection=True,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_downsample:
            self.downsampler = TemporalDownsampler3D(out_channels, out_channels)
            self.temporal_downsample_factor = 2
        else:
            self.downsampler = None
            self.temporal_downsample_factor = 1

        self.spatial_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv, attn in zip(self.convs, self.attentions):
            x = conv(x)
            x = attn(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.downsampler is not None:
            x = self.downsampler(x)

        return x
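A sketch of building a block through the get_down_block factory, assuming the import path above and that the sibling .attention and .activations modules (not part of this commit) are available:

import torch
from ruyi.vae.ldm.modules.vaemodules.down_blocks import get_down_block  # assumed import path

block = get_down_block(
    "SpatialTemporalDownBlock3D",
    in_channels=64,
    out_channels=128,
    num_layers=2,
    act_fn="silu",
)
x = torch.randn(1, 64, 4, 32, 32)  # (B, C, T, H, W)
print(block(x).shape)  # torch.Size([1, 128, 2, 16, 16]): T and H/W halved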
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/downsamplers.py (new file, mode 100644)

import torch
import torch.nn as nn
import torch.nn.functional as F

from .common import CausalConv3d


class Downsampler(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        spatial_downsample_factor: int = 1,
        temporal_downsample_factor: int = 1,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.spatial_downsample_factor = spatial_downsample_factor
        self.temporal_downsample_factor = temporal_downsample_factor


class SpatialDownsampler3D(Downsampler):
    def __init__(self, in_channels: int, out_channels):
        if out_channels is None:
            out_channels = in_channels

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            spatial_downsample_factor=2,
            temporal_downsample_factor=1,
        )

        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(1, 2, 2),
            padding=0,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, (0, 1, 0, 1))
        return self.conv(x)


class TemporalDownsampler3D(Downsampler):
    def __init__(self, in_channels: int, out_channels):
        if out_channels is None:
            out_channels = in_channels

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            spatial_downsample_factor=1,
            temporal_downsample_factor=2,
        )

        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(2, 1, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class SpatialTemporalDownsampler3D(Downsampler):
    def __init__(self, in_channels: int, out_channels):
        if out_channels is None:
            out_channels = in_channels

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            spatial_downsample_factor=2,
            temporal_downsample_factor=2,
        )

        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(2, 2, 2),
            padding=0,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, (0, 1, 0, 1))
        return self.conv(x)


class BlurPooling2D(Downsampler):
    def __init__(self, in_channels: int, out_channels):
        if out_channels is None:
            out_channels = in_channels
        assert in_channels == out_channels

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            spatial_downsample_factor=2,
            temporal_downsample_factor=1,
        )

        filt = torch.tensor([1, 2, 1], dtype=torch.float32)
        filt = torch.einsum("i,j -> ij", filt, filt)
        filt = filt / filt.sum()
        filt = filt[None, None].repeat(out_channels, 1, 1, 1)
        self.register_buffer("filt", filt)
        self.filt: torch.Tensor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W)
        return F.conv2d(x, self.filt, stride=2, padding=1, groups=self.in_channels)


class BlurPooling3D(Downsampler):
    def __init__(self, in_channels: int, out_channels):
        if out_channels is None:
            out_channels = in_channels
        assert in_channels == out_channels

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            spatial_downsample_factor=2,
            temporal_downsample_factor=2,
        )

        filt = torch.tensor([1, 2, 1], dtype=torch.float32)
        filt = torch.einsum("i,j,k -> ijk", filt, filt, filt)
        filt = filt / filt.sum()
        filt = filt[None, None].repeat(out_channels, 1, 1, 1, 1)
        self.register_buffer("filt", filt)
        self.filt: torch.Tensor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        return F.conv3d(x, self.filt, stride=2, padding=1, groups=self.in_channels)
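The blur-pooling classes build a fixed binomial kernel and apply it depthwise; the construction can be reproduced standalone (the channel count here is illustrative):

import torch
import torch.nn.functional as F

# The binomial blur kernel BlurPooling2D builds: outer product of [1, 2, 1], normalized.
filt = torch.tensor([1., 2., 1.])
filt = torch.einsum("i,j -> ij", filt, filt)
filt = filt / filt.sum()

channels = 4
weight = filt[None, None].repeat(channels, 1, 1, 1)  # one kernel per channel (depthwise)
x = torch.randn(1, channels, 8, 8)
y = F.conv2d(x, weight, stride=2, padding=1, groups=channels)
print(y.shape)  # torch.Size([1, 4, 4, 4]): anti-aliased 2x spatial downsample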
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/gc_block.py (new file, mode 100644)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class GlobalContextBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        min_channels: int = 16,
        init_bias: float = -10.,
        fusion_type: str = "mul",
    ):
        super().__init__()

        assert fusion_type in ("mul", "add"), f"Unsupported fusion type: {fusion_type}"
        self.fusion_type = fusion_type

        self.conv_ctx = nn.Conv2d(in_channels, 1, kernel_size=1)

        num_channels = max(min_channels, out_channels // 2)
        if fusion_type == "mul":
            self.conv_mul = nn.Sequential(
                nn.Conv2d(in_channels, num_channels, kernel_size=1),
                nn.LayerNorm([num_channels, 1, 1]),  # TODO: LayerNorm or GroupNorm?
                nn.LeakyReLU(0.1),
                nn.Conv2d(num_channels, out_channels, kernel_size=1),
                nn.Sigmoid(),
            )
            nn.init.zeros_(self.conv_mul[-2].weight)
            nn.init.constant_(self.conv_mul[-2].bias, init_bias)
        else:
            self.conv_add = nn.Sequential(
                nn.Conv2d(in_channels, num_channels, kernel_size=1),
                nn.LayerNorm([num_channels, 1, 1]),  # TODO: LayerNorm or GroupNorm?
                nn.LeakyReLU(0.1),
                nn.Conv2d(num_channels, out_channels, kernel_size=1),
            )
            nn.init.zeros_(self.conv_add[-1].weight)
            nn.init.constant_(self.conv_add[-1].bias, init_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        is_image = x.ndim == 4
        if is_image:
            x = rearrange(x, "b c h w -> b c 1 h w")

        # x: (B, C, T, H, W)
        orig_x = x
        batch_size = x.shape[0]
        x = rearrange(x, "b c t h w -> (b t) c h w")

        ctx = self.conv_ctx(x)
        ctx = rearrange(ctx, "b c h w -> b c (h w)")
        ctx = F.softmax(ctx, dim=-1)

        flattened_x = rearrange(x, "b c h w -> b c (h w)")
        x = torch.einsum("b c1 n, b c2 n -> b c2 c1", ctx, flattened_x)
        x = rearrange(x, "... -> ... 1")

        if self.fusion_type == "mul":
            mul_term = self.conv_mul(x)
            mul_term = rearrange(mul_term, "(b t) c h w -> b c t h w", b=batch_size)
            x = orig_x * mul_term
        else:
            add_term = self.conv_add(x)
            add_term = rearrange(add_term, "(b t) c h w -> b c t h w", b=batch_size)
            x = orig_x + add_term

        if is_image:
            x = rearrange(x, "b c 1 h w -> b c h w")

        return x
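A usage sketch for GlobalContextBlock (import path assumed). The block pools a softmax-weighted global context per frame and gates the input with it, so the output shape matches the input; 4D image input is also accepted:

import torch
from ruyi.vae.ldm.modules.vaemodules.gc_block import GlobalContextBlock  # assumed import path

gc = GlobalContextBlock(in_channels=32, out_channels=32, fusion_type="mul")
x = torch.randn(2, 32, 4, 16, 16)  # (B, C, T, H, W)
print(gc(x).shape)  # torch.Size([2, 32, 4, 16, 16]): shape-preserving gating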
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/mid_blocks.py (new file, mode 100755)

import torch
import torch.nn as nn

from .attention import Attention3D, SpatialAttention, TemporalAttention
from .common import ResidualBlock3D


def get_mid_block(
    mid_block_type: str,
    in_channels: int,
    num_layers: int,
    act_fn: str,
    norm_num_groups: int = 32,
    norm_eps: float = 1e-6,
    dropout: float = 0.0,
    add_attention: bool = True,
    attention_type: str = "3d",
    num_attention_heads: int = 1,
    output_scale_factor: float = 1.0,
) -> nn.Module:
    if mid_block_type == "MidBlock3D":
        return MidBlock3D(
            in_channels=in_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            add_attention=add_attention,
            attention_type=attention_type,
            attention_head_dim=in_channels // num_attention_heads,
            output_scale_factor=output_scale_factor,
        )
    else:
        raise ValueError(f"Unknown mid block type: {mid_block_type}")


class MidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`MidBlock3D`] with multiple residual blocks and optional attention blocks.

    Args:
        in_channels (`int`): The number of input channels.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
        norm_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
        attention_type (`str`, *optional*, defaults to `3d`): The type of attention to use.
        attention_head_dim (`int`, *optional*, defaults to 1):
            Dimension of a single attention head. The number of attention heads is determined based on this value and
            the number of input channels.
        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, temporal_length, height, width)`.
    """

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        add_attention: bool = True,
        attention_type: str = "3d",
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()

        self.attention_type = attention_type
        norm_num_groups = norm_num_groups if norm_num_groups is not None else min(in_channels // 4, 32)

        self.convs = nn.ModuleList([
            ResidualBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                non_linearity=act_fn,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                dropout=dropout,
                output_scale_factor=output_scale_factor,
            )
        ])

        self.attentions = nn.ModuleList([])
        for _ in range(num_layers - 1):
            if add_attention:
                if attention_type == "3d":
                    self.attentions.append(
                        Attention3D(
                            in_channels,
                            nheads=in_channels // attention_head_dim,
                            head_dim=attention_head_dim,
                            bias=True,
                            upcast_softmax=True,
                            norm_num_groups=norm_num_groups,
                            eps=norm_eps,
                            rescale_output_factor=output_scale_factor,
                            residual_connection=True,
                        )
                    )
                elif attention_type == "spatial_temporal":
                    self.attentions.append(
                        nn.ModuleList([
                            SpatialAttention(
                                in_channels,
                                nheads=in_channels // attention_head_dim,
                                head_dim=attention_head_dim,
                                bias=True,
                                upcast_softmax=True,
                                norm_num_groups=norm_num_groups,
                                eps=norm_eps,
                                rescale_output_factor=output_scale_factor,
                                residual_connection=True,
                            ),
                            TemporalAttention(
                                in_channels,
                                nheads=in_channels // attention_head_dim,
                                head_dim=attention_head_dim,
                                bias=True,
                                upcast_softmax=True,
                                norm_num_groups=norm_num_groups,
                                eps=norm_eps,
                                rescale_output_factor=output_scale_factor,
                                residual_connection=True,
                            ),
                        ])
                    )
                elif attention_type == "spatial":
                    self.attentions.append(
                        SpatialAttention(
                            in_channels,
                            nheads=in_channels // attention_head_dim,
                            head_dim=attention_head_dim,
                            bias=True,
                            upcast_softmax=True,
                            norm_num_groups=norm_num_groups,
                            eps=norm_eps,
                            rescale_output_factor=output_scale_factor,
                            residual_connection=True,
                        )
                    )
                elif attention_type == "temporal":
                    self.attentions.append(
                        TemporalAttention(
                            in_channels,
                            nheads=in_channels // attention_head_dim,
                            head_dim=attention_head_dim,
                            bias=True,
                            upcast_softmax=True,
                            norm_num_groups=norm_num_groups,
                            eps=norm_eps,
                            rescale_output_factor=output_scale_factor,
                            residual_connection=True,
                        )
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attention_type}")
            else:
                self.attentions.append(None)

            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.convs[0](hidden_states)

        for attn, resnet in zip(self.attentions, self.convs[1:]):
            if attn is not None:
                if self.attention_type == "spatial_temporal":
                    spatial_attn, temporal_attn = attn
                    hidden_states = spatial_attn(hidden_states)
                    hidden_states = temporal_attn(hidden_states)
                else:
                    hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states)

        return hidden_states
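A sketch of the factory in use, assuming the import path above and the sibling .attention module (which is not part of this commit). Mid blocks preserve the input shape:

import torch
from ruyi.vae.ldm.modules.vaemodules.mid_blocks import get_mid_block  # assumed import path

mid = get_mid_block(
    "MidBlock3D",
    in_channels=128,
    num_layers=2,
    act_fn="silu",
    attention_type="spatial_temporal",
    num_attention_heads=4,
)
x = torch.randn(1, 128, 4, 16, 16)  # (B, C, T, H, W)
print(mid(x).shape)  # torch.Size([1, 128, 4, 16, 16])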
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/up_blocks.py (new file, mode 100755)

import torch
import torch.nn as nn

from .attention import SpatialAttention, TemporalAttention
from .common import ResidualBlock3D
from .gc_block import GlobalContextBlock
from .upsamplers import (
    SpatialTemporalUpsampler3D,
    SpatialUpsampler3D,
    TemporalUpsampler3D
)


def get_up_block(
    up_block_type: str,
    in_channels: int,
    out_channels: int,
    num_layers: int,
    act_fn: str,
    norm_num_groups: int = 32,
    norm_eps: float = 1e-6,
    dropout: float = 0.0,
    num_attention_heads: int = 1,
    output_scale_factor: float = 1.0,
    add_gc_block: bool = False,
    add_upsample: bool = True,
) -> nn.Module:
    if up_block_type == "SpatialUpBlock3D":
        return SpatialUpBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_upsample=add_upsample,
        )
    elif up_block_type == "SpatialAttnUpBlock3D":
        return SpatialAttnUpBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            attention_head_dim=out_channels // num_attention_heads,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_upsample=add_upsample,
        )
    elif up_block_type == "TemporalUpBlock3D":
        return TemporalUpBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_upsample=add_upsample,
        )
    elif up_block_type == "TemporalAttnUpBlock3D":
        return TemporalAttnUpBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            attention_head_dim=out_channels // num_attention_heads,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_upsample=add_upsample,
        )
    elif up_block_type == "SpatialTemporalUpBlock3D":
        return SpatialTemporalUpBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_upsample=add_upsample,
        )
    else:
        raise ValueError(f"Unknown up block type: {up_block_type}")


class SpatialUpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()

        if add_upsample:
            self.upsampler = SpatialUpsampler3D(in_channels, in_channels)
        else:
            self.upsampler = None

        if add_gc_block:
            self.gc_block = GlobalContextBlock(in_channels, in_channels, fusion_type="mul")
        else:
            self.gc_block = None

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.upsampler is not None:
            x = self.upsampler(x)

        return x


class SpatialAttnUpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        self.attentions = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )
            self.attentions.append(
                SpatialAttention(
                    out_channels,
                    nheads=out_channels // attention_head_dim,
                    head_dim=attention_head_dim,
                    bias=True,
                    upcast_softmax=True,
                    norm_num_groups=norm_num_groups,
                    eps=norm_eps,
                    rescale_output_factor=output_scale_factor,
                    residual_connection=True,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_upsample:
            self.upsampler = SpatialUpsampler3D(out_channels, out_channels)
        else:
            self.upsampler = None

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv, attn in zip(self.convs, self.attentions):
            x = conv(x)
            x = attn(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.upsampler is not None:
            x = self.upsampler(x)

        return x


class TemporalUpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_upsample:
            self.upsampler = TemporalUpsampler3D(out_channels, out_channels)
        else:
            self.upsampler = None

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.upsampler is not None:
            x = self.upsampler(x)

        return x


class TemporalAttnUpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        self.attentions = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )
            self.attentions.append(
                TemporalAttention(
                    out_channels,
                    nheads=out_channels // attention_head_dim,
                    head_dim=attention_head_dim,
                    bias=True,
                    upcast_softmax=True,
                    norm_num_groups=norm_num_groups,
                    eps=norm_eps,
                    rescale_output_factor=output_scale_factor,
                    residual_connection=True,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_upsample:
            self.upsampler = TemporalUpsampler3D(out_channels, out_channels)
        else:
            self.upsampler = None

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv, attn in zip(self.convs, self.attentions):
            x = conv(x)
            x = attn(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.upsampler is not None:
            x = self.upsampler(x)

        return x


class SpatialTemporalUpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_upsample:
            self.upsampler = SpatialTemporalUpsampler3D(out_channels, out_channels)
        else:
            self.upsampler = None

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.upsampler is not None:
            x = self.upsampler(x)

        return x
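A matching sketch for the up-block factory, under the same import-path and sibling-module assumptions. With the causal temporal upsampler, T frames become 2T - 1, since the first frame is not duplicated:

import torch
from ruyi.vae.ldm.modules.vaemodules.up_blocks import get_up_block  # assumed import path

block = get_up_block(
    "SpatialTemporalUpBlock3D",
    in_channels=128,
    out_channels=64,
    num_layers=2,
    act_fn="silu",
)
x = torch.randn(1, 128, 3, 16, 16)  # (B, C, T, H, W)
print(block(x).shape)  # torch.Size([1, 64, 5, 32, 32]): H/W doubled, T -> 2T - 1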
Ruyi-Models/ruyi/vae/ldm/modules/vaemodules/upsamplers.py (new file, mode 100644)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .common import CausalConv3d


class Upsampler(nn.Module):
    def __init__(
        self,
        spatial_upsample_factor: int = 1,
        temporal_upsample_factor: int = 1,
    ):
        super().__init__()

        self.spatial_upsample_factor = spatial_upsample_factor
        self.temporal_upsample_factor = temporal_upsample_factor


class SpatialUpsampler3D(Upsampler):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(spatial_upsample_factor=2)

        if out_channels is None:
            out_channels = in_channels

        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")
        x = self.conv(x)
        return x


class SpatialUpsamplerD2S3D(Upsampler):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(spatial_upsample_factor=2)

        if out_channels is None:
            out_channels = in_channels

        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels * 4,
            kernel_size=3,
        )

        o, i, t, h, w = self.conv.weight.shape
        conv_weight = torch.empty(o // 4, i, t, h, w)
        nn.init.kaiming_normal_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")

        self.conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = rearrange(x, "b (c p1 p2) t h w -> b c t (h p1) (w p2)", p1=2, p2=2)
        return x


class TemporalUpsampler3D(Upsampler):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(
            spatial_upsample_factor=1,
            temporal_upsample_factor=2,
        )

        if out_channels is None:
            out_channels = in_channels

        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[2] > 1:
            first_frame, x = x[:, :, :1], x[:, :, 1:]
            x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
            x = torch.cat([first_frame, x], dim=2)
        x = self.conv(x)
        return x


class TemporalUpsamplerD2S3D(Upsampler):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(
            spatial_upsample_factor=1,
            temporal_upsample_factor=2,
        )

        if out_channels is None:
            out_channels = in_channels

        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels * 2,
            kernel_size=3,
        )

        o, i, t, h, w = self.conv.weight.shape
        conv_weight = torch.empty(o // 2, i, t, h, w)
        nn.init.kaiming_normal_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")

        self.conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = rearrange(x, "b (c p1) t h w -> b c (t p1) h w", p1=2)
        x = x[:, :, 1:]
        return x


class SpatialTemporalUpsampler3D(Upsampler):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(
            spatial_upsample_factor=2,
            temporal_upsample_factor=2,
        )

        if out_channels is None:
            out_channels = in_channels

        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
        )

        self.padding_flag = 0
        self.set_3dgroupnorm = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")
        x = self.conv(x)

        if self.padding_flag == 0:
            if x.shape[2] > 1:
                first_frame, x = x[:, :, :1], x[:, :, 1:]
                x = F.interpolate(
                    x,
                    scale_factor=(2, 1, 1),
                    mode="trilinear" if not self.set_3dgroupnorm else "nearest",
                )
                x = torch.cat([first_frame, x], dim=2)
        elif self.padding_flag == 2 or self.padding_flag == 5 or self.padding_flag == 6:
            x = F.interpolate(
                x,
                scale_factor=(2, 1, 1),
                mode="trilinear" if not self.set_3dgroupnorm else "nearest",
            )
        return x


class SpatialTemporalUpsamplerD2S3D(Upsampler):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(
            spatial_upsample_factor=2,
            temporal_upsample_factor=2,
        )

        if out_channels is None:
            out_channels = in_channels

        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels * 8,
            kernel_size=3,
        )

        o, i, t, h, w = self.conv.weight.shape
        conv_weight = torch.empty(o // 8, i, t, h, w)
        nn.init.kaiming_normal_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 8) ...")

        self.conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = rearrange(x, "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", p1=2, p2=2, p3=2)
        x = x[:, :, 1:]
        return x
Ruyi-Models/ruyi/vae/ldm/util.py
0 → 100644
View file @
08a21d59
import
importlib
import
multiprocessing
as
mp
from
collections
import
abc
from
functools
import
partial
from
inspect
import
isfunction
from
queue
import
Queue
from
threading
import
Thread
import
numpy
as
np
import
torch
from
einops
import
rearrange
from
PIL
import
Image
,
ImageDraw
,
ImageFont
def
log_txt_as_img
(
wh
,
xc
,
size
=
10
):
# wh a tuple of (width, height)
# xc a list of captions to plot
b
=
len
(
xc
)
txts
=
list
()
for
bi
in
range
(
b
):
txt
=
Image
.
new
(
"RGB"
,
wh
,
color
=
"white"
)
draw
=
ImageDraw
.
Draw
(
txt
)
font
=
ImageFont
.
truetype
(
'data/DejaVuSans.ttf'
,
size
=
size
)
nc
=
int
(
40
*
(
wh
[
0
]
/
256
))
lines
=
"
\n
"
.
join
(
xc
[
bi
][
start
:
start
+
nc
]
for
start
in
range
(
0
,
len
(
xc
[
bi
]),
nc
))
try
:
draw
.
text
((
0
,
0
),
lines
,
fill
=
"black"
,
font
=
font
)
except
UnicodeEncodeError
:
print
(
"Cant encode string for logging. Skipping."
)
txt
=
np
.
array
(
txt
).
transpose
(
2
,
0
,
1
)
/
127.5
-
1.0
txts
.
append
(
txt
)
txts
=
np
.
stack
(
txts
)
txts
=
torch
.
tensor
(
txts
)
return
txts
def
ismap
(
x
):
if
not
isinstance
(
x
,
torch
.
Tensor
):
return
False
return
(
len
(
x
.
shape
)
==
4
)
and
(
x
.
shape
[
1
]
>
3
)
def
isimage
(
x
):
if
not
isinstance
(
x
,
torch
.
Tensor
):
return
False
return
(
len
(
x
.
shape
)
==
4
)
and
(
x
.
shape
[
1
]
==
3
or
x
.
shape
[
1
]
==
1
)
def
exists
(
x
):
return
x
is
not
None
def
default
(
val
,
d
):
if
exists
(
val
):
return
val
return
d
()
if
isfunction
(
d
)
else
d
def
mean_flat
(
tensor
):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return
tensor
.
mean
(
dim
=
list
(
range
(
1
,
len
(
tensor
.
shape
))))
def
count_params
(
model
,
verbose
=
False
):
total_params
=
sum
(
p
.
numel
()
for
p
in
model
.
parameters
())
if
verbose
:
print
(
f
"
{
model
.
__class__
.
__name__
}
has
{
total_params
*
1.e-6
:.
2
f
}
M params."
)
return
total_params
def
instantiate_from_config
(
config
):
if
not
"target"
in
config
:
if
config
==
'__is_first_stage__'
:
return
None
elif
config
==
"__is_unconditional__"
:
return
None
raise
KeyError
(
"Expected key `target` to instantiate."
)
return
get_obj_from_str
(
config
[
"target"
])(
**
config
.
get
(
"params"
,
dict
()))
def
get_obj_from_str
(
string
,
reload
=
False
):
module
,
cls
=
string
.
rsplit
(
"."
,
1
)
if
reload
:
module_imp
=
importlib
.
import_module
(
module
)
importlib
.
reload
(
module_imp
)
return
getattr
(
importlib
.
import_module
(
module
,
package
=
None
),
cls
)
def
_do_parallel_data_prefetch
(
func
,
Q
,
data
,
idx
,
idx_to_fn
=
False
):
# create dummy dataset instance
# run prefetching
if
idx_to_fn
:
res
=
func
(
data
,
worker_id
=
idx
)
else
:
res
=
func
(
data
)
Q
.
put
([
idx
,
res
])
Q
.
put
(
"Done"
)
def
parallel_data_prefetch
(
func
:
callable
,
data
,
n_proc
,
target_data_type
=
"ndarray"
,
cpu_intensive
=
True
,
use_worker_id
=
False
):
# if target_data_type not in ["ndarray", "list"]:
# raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# )
if
isinstance
(
data
,
np
.
ndarray
)
and
target_data_type
==
"list"
:
raise
ValueError
(
"list expected but function got ndarray."
)
elif
isinstance
(
data
,
abc
.
Iterable
):
if
isinstance
(
data
,
dict
):
print
(
f
'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data
=
list
(
data
.
values
())
if
target_data_type
==
"ndarray"
:
data
=
np
.
asarray
(
data
)
else
:
data
=
list
(
data
)
else
:
raise
TypeError
(
f
"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually
{
type
(
data
)
}
."
)
if
cpu_intensive
:
Q
=
mp
.
Queue
(
1000
)
proc
=
mp
.
Process
else
:
Q
=
Queue
(
1000
)
proc
=
Thread
# spawn processes
if
target_data_type
==
"ndarray"
:
arguments
=
[
[
func
,
Q
,
part
,
i
,
use_worker_id
]
for
i
,
part
in
enumerate
(
np
.
array_split
(
data
,
n_proc
))
]
else
:
step
=
(
int
(
len
(
data
)
/
n_proc
+
1
)
if
len
(
data
)
%
n_proc
!=
0
else
int
(
len
(
data
)
/
n_proc
)
)
arguments
=
[
[
func
,
Q
,
part
,
i
,
use_worker_id
]
for
i
,
part
in
enumerate
(
[
data
[
i
:
i
+
step
]
for
i
in
range
(
0
,
len
(
data
),
step
)]
)
]
processes
=
[]
for
i
in
range
(
n_proc
):
p
=
proc
(
target
=
_do_parallel_data_prefetch
,
args
=
arguments
[
i
])
processes
+=
[
p
]
# start processes
print
(
f
"Start prefetching..."
)
import
time
start
=
time
.
time
()
gather_res
=
[[]
for
_
in
range
(
n_proc
)]
try
:
for
p
in
processes
:
p
.
start
()
k
=
0
while
k
<
n_proc
:
# get result
res
=
Q
.
get
()
if
res
==
"Done"
:
k
+=
1
else
:
gather_res
[
res
[
0
]]
=
res
[
1
]
except
Exception
as
e
:
print
(
"Exception: "
,
e
)
for
p
in
processes
:
p
.
terminate
()
raise
e
finally
:
for
p
in
processes
:
p
.
join
()
print
(
f
"Prefetching complete. [
{
time
.
time
()
-
start
}
sec.]"
)
if
target_data_type
==
'ndarray'
:
if
not
isinstance
(
gather_res
[
0
],
np
.
ndarray
):
return
np
.
concatenate
([
np
.
asarray
(
r
)
for
r
in
gather_res
],
axis
=
0
)
# order outputs
return
np
.
concatenate
(
gather_res
,
axis
=
0
)
elif
target_data_type
==
'list'
:
out
=
[]
for
r
in
gather_res
:
out
.
extend
(
r
)
return
out
else
:
return
gather_res
Ruyi-Models/ruyi/vae/setup.py
0 → 100644
View file @
08a21d59
from
setuptools
import
find_packages
,
setup
setup
(
name
=
'latent-diffusion'
,
version
=
'0.0.1'
,
description
=
''
,
packages
=
find_packages
(),
install_requires
=
[
'torch'
,
'numpy'
,
'tqdm'
,
],
)
\ No newline at end of file
assets/二维码.jpeg
0 → 100644
View file @
08a21d59
106 KB
hf_down.py
0 → 100644
View file @
08a21d59
from
huggingface_hub
import
hf_hub_download
ckpt
=
hf_hub_download
(
repo_id
=
"TencentARC/NVComposer"
,
filename
=
"NVComposer-V0.1.ckpt"
,
repo_type
=
"model"
,
local_dir
=
"./models"
)
start.sh
0 → 100644
View file @
08a21d59
#!/bin/bash
cd
/root/Ruyi-Models
python app.py
启动器.ipynb
0 → 100644
View file @
08a21d59
{
"cells": [
{
"cell_type": "markdown",
"id": "e5c5a211-2ccd-4341-af10-ac546484b91f",
"metadata": {
"tags": []
},
"source": [
"## 项目介绍\n",
"- 原项目地址:https://huggingface.co/IamCreateAI/Ruyi-Mini-7B\n",
"- Ruyi-Mini-7B是一种开源图像转视频生成模型。从输入图像开始,Ruyi生成分辨率从360p到720p的后续视频帧,支持各种宽高比,最长持续时间为5秒。通过运动和摄像头控制增强,Ruyi在视频生成方面提供了更大的灵活性和创造力。\n",
"- 项目在L20显卡,cuda12.2上进行适配\n",
"## 使用说明\n",
"- 启动和重启 Notebook 点上方工具栏中的「重启并运行所有单元格」。出现如下内容就算成功了:\n",
" - `Running on local URL: http://0.0.0.0:7860`\n",
" - `Running on public URL: https://xxxxxxxxxxxxxxx.gradio.live`\n",
"- 通过以下方式开启页面:\n",
" - 控制台打开「自定义服务」了,访问自定义服务端口号设置为7860\n",
" - 直接打开显示的公开链接`public URL`\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53a96614-e2d2-4710-a82b-0d5ca9cb9872",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# 启动\n",
"!sh start.sh"
]
},
{
"cell_type": "markdown",
"source": [
"---\n",
"**扫码关注公众号,获取更多资讯**<br>\n",
"<div align=center>\n",
"<img src=\"assets/二维码.jpeg\" width = 20% />\n",
"</div>\n"
],
"metadata": {
"collapsed": false
},
"id": "2f54158c2967bc25"
},
{
"cell_type": "code",
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "6dc59fbbcf222b6b"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment