renzhc / diffusers_dcu · Commits · 635da723
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f912f39b50f87e50a9d99346f5c1b6e644653262"
Commit 635da723, authored Jun 28, 2022 by Patrick von Platen

one attention module only

parent 79db3eb6
Showing 4 changed files with 137 additions and 107 deletions
src/diffusers/models/attention2d.py                 +132  -58
src/diffusers/models/embeddings.py                    +0  -16
src/diffusers/models/unet_grad_tts.py                 +2   -1
src/diffusers/models/unet_sde_score_estimation.py     +3  -32
src/diffusers/models/attention2d.py
 import math
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 
 
 # unet_grad_tts.py
+# TODO(Patrick) - weird linear attention layer. Check with: https://github.com/huawei-noah/Speech-Backbones/issues/15
 class LinearAttention(torch.nn.Module):
     def __init__(self, dim, heads=4, dim_head=32):
         super(LinearAttention, self).__init__()
...
@@ -18,7 +18,6 @@ class LinearAttention(torch.nn.Module):
     def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
-        # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
         q, k, v = (
             qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
             .permute(1, 0, 2, 3, 4, 5)
...
@@ -27,12 +26,11 @@ class LinearAttention(torch.nn.Module):
         k = k.softmax(dim=-1)
         context = torch.einsum("bhdn,bhen->bhde", k, v)
         out = torch.einsum("bhde,bhdn->bhen", context, q)
-        # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
         out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
         return self.to_out(out)
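
The dropped comments were the einops rearrange calls that the pure-torch reshape/permute chain above replaces. A standalone sanity check of that equivalence; einops is imported only for the comparison, and the trailing reshape that merges the spatial axes is an assumption, since the hunk cuts off after the permute:

# Standalone sanity check (not part of the commit).
import torch
from einops import rearrange

b, heads, dim_head, h, w = 2, 4, 32, 8, 8
qkv = torch.randn(b, 3 * heads * dim_head, h, w)

# What the removed comment described:
q1, k1, v1 = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=heads, qkv=3)

# What the kept torch code computes:
q2, k2, v2 = (
    qkv.reshape(b, 3, heads, dim_head, h, w)
    .permute(1, 0, 2, 3, 4, 5)
    .reshape(3, b, heads, dim_head, h * w)  # assumed continuation of the truncated chain
)

assert torch.equal(q1, q2) and torch.equal(k1, k2) and torch.equal(v1, v2)
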
-# unet_glide.py & unet_ldm.py
+# the main attention block that is used for all models
 class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other.
...
@@ -46,10 +44,13 @@ class AttentionBlock(nn.Module):
         channels,
         num_heads=1,
         num_head_channels=-1,
+        num_groups=32,
         use_checkpoint=False,
         encoder_channels=None,
-        use_new_attention_order=False,
+        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
         overwrite_qkv=False,
+        overwrite_linear=False,
+        rescale_output_factor=1.0,
     ):
         super().__init__()
         self.channels = channels
...
@@ -62,23 +63,34 @@ class AttentionBlock(nn.Module):
             self.num_heads = channels // num_head_channels
         self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels, swish=0.0)
-        self.qkv = conv_nd(1, channels, channels * 3, 1)
+        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+        self.qkv = nn.Conv1d(channels, channels * 3, 1)
         self.n_heads = self.num_heads
+        self.rescale_output_factor = rescale_output_factor
 
         if encoder_channels is not None:
-            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
+            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
 
-        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
 
         self.overwrite_qkv = overwrite_qkv
         if overwrite_qkv:
             in_channels = channels
+            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
             self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
 
+        self.overwrite_linear = overwrite_linear
+        if self.overwrite_linear:
+            num_groups = min(channels // 4, 32)
+            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
+            self.NIN_0 = NIN(channels, channels)
+            self.NIN_1 = NIN(channels, channels)
+            self.NIN_2 = NIN(channels, channels)
+            self.NIN_3 = NIN(channels, channels)
+
         self.is_overwritten = False
 
     def set_weights(self, module):
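
The constructor now writes nn.Conv1d and nn.GroupNorm directly instead of going through the conv_nd and normalization helpers that this commit deletes further down. A standalone check, reusing the removed GroupNorm32 definition, that normalization(channels, swish=0.0) and a plain nn.GroupNorm agree (arbitrary shapes, not part of the commit):

import torch
import torch.nn.functional as F
from torch import nn


class GroupNorm32(nn.GroupNorm):  # copied from the code this commit removes below
    def __init__(self, num_groups, num_channels, swish, eps=1e-5, affine=True):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine)
        self.swish = swish

    def forward(self, x):
        y = super().forward(x.float()).to(x.dtype)
        if self.swish == 1.0:
            y = F.silu(y)
        elif self.swish:
            y = y * F.sigmoid(y * float(self.swish))
        return y


channels = 64
x = torch.randn(2, channels, 16)
old_norm = GroupNorm32(num_groups=32, num_channels=channels, swish=0.0)  # what normalization(channels, swish=0.0) returned
new_norm = nn.GroupNorm(num_channels=channels, num_groups=32, eps=1e-5, affine=True)
assert torch.allclose(old_norm(x), new_norm(x), atol=1e-6)
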
...
@@ -89,11 +101,17 @@ class AttentionBlock(nn.Module):
             self.qkv.weight.data = qkv_weight
             self.qkv.bias.data = qkv_bias
 
-            proj_out = zero_module(conv_nd(1, self.channels, self.channels, 1))
+            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
             proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
             proj_out.bias.data = module.proj_out.bias.data
 
             self.proj_out = proj_out
+        elif self.overwrite_linear:
+            self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
+            self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
+
+            self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
+            self.proj_out.bias.data = self.NIN_3.b.data
 
     def forward(self, x, encoder_out=None):
         if self.overwrite_qkv and not self.is_overwritten:
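
The new elif branch ports NIN weights into the fused 1x1 Conv1d projections. A shape sketch with placeholder tensors showing why the transpose and the trailing [:, :, None] are needed; it assumes NIN stores W as (in_dim, num_units) and b as (num_units,), as in the NIN class kept at the bottom of this file:

import torch
from torch import nn

channels = 64
W = torch.randn(channels, channels)   # stands in for self.NIN_x.W.data
b = torch.randn(channels)             # stands in for self.NIN_x.b.data

qkv = nn.Conv1d(channels, channels * 3, 1)
print(qkv.weight.shape)               # torch.Size([192, 64, 1]) -> (out_channels, in_channels, kernel)

# Each NIN weight is transposed to (num_units, in_dim) and given a length-1 kernel
# axis before being stacked into the fused qkv projection, as in set_weights above:
qkv.weight.data = torch.concat([W.T, W.T, W.T], dim=0)[:, :, None]
qkv.bias.data = torch.concat([b, b, b], dim=0)
assert qkv.weight.shape == (channels * 3, channels, 1)
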
...
@@ -124,69 +142,74 @@ class AttentionBlock(nn.Module):
         h = a.reshape(bs, -1, length)
 
         h = self.proj_out(h)
-        return x + h.reshape(b, c, *spatial)
-
-
-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class GroupNorm32(nn.GroupNorm):
-    def __init__(self, num_groups, num_channels, swish, eps=1e-5, affine=True):
-        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine)
-        self.swish = swish
-
-    def forward(self, x):
-        y = super().forward(x.float()).to(x.dtype)
-        if self.swish == 1.0:
-            y = F.silu(y)
-        elif self.swish:
-            y = y * F.sigmoid(y * float(self.swish))
-        return y
-
-
-def normalization(channels, swish=0.0, eps=1e-5):
-    """
-    Make a standard normalization layer, with an optional swish activation.
-    :param channels: number of input channels. :return: an nn.Module for normalization.
-    """
-    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish, eps=eps, affine=True)
-
-
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
+        h = h.reshape(b, c, *spatial)
+
+        result = x + h
+
+        result = result / self.rescale_output_factor
+
+        return result
 
 
 # unet_score_estimation.py
 # class AttnBlockpp(nn.Module):
 #     """Channel-wise self-attention block. Modified from DDPM."""
 #
-#     def __init__(self, channels, skip_rescale=False, init_scale=0.0):
+#     def __init__(
+#         self,
+#         channels,
+#         skip_rescale=False,
+#         init_scale=0.0,
+#         num_heads=1,
+#         num_head_channels=-1,
+#         use_checkpoint=False,
+#         encoder_channels=None,
+#         use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
+#         overwrite_qkv=False,
+#         overwrite_from_grad_tts=False,
+#     ):
 #         super().__init__()
-#         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
+#         num_groups = min(channels // 4, 32)
+#         self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
 #         self.NIN_0 = NIN(channels, channels)
 #         self.NIN_1 = NIN(channels, channels)
 #         self.NIN_2 = NIN(channels, channels)
 #         self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
 #         self.skip_rescale = skip_rescale
 #
+#         self.channels = channels
+#         if num_head_channels == -1:
+#             self.num_heads = num_heads
+#         else:
+#             assert (
+#                 channels % num_head_channels == 0
+#             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+#             self.num_heads = channels // num_head_channels
+#
+#         self.use_checkpoint = use_checkpoint
+#         self.norm = normalization(channels, num_groups=num_groups, eps=1e-6, swish=None)
+#         self.qkv = conv_nd(1, channels, channels * 3, 1)
+#         self.n_heads = self.num_heads
+#
+#         if encoder_channels is not None:
+#             self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
+#
+#         self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+#
+#         self.is_weight_set = False
+#
+#     def set_weights(self):
+#         self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
+#         self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
+#
+#         self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
+#         self.proj_out.bias.data = self.NIN_3.b.data
+#
 #     def forward(self, x):
+#         if not self.is_weight_set:
+#             self.set_weights()
+#             self.is_weight_set = True
+#
 #         B, C, H, W = x.shape
 #         h = self.GroupNorm_0(x)
 #         q = self.NIN_0(h)
...
@@ -199,7 +222,58 @@ def zero_module(module):
 #         w = torch.reshape(w, (B, H, W, H, W))
 #         h = torch.einsum("bhwij,bcij->bchw", w, v)
 #         h = self.NIN_3(h)
 #
 #         if not self.skip_rescale:
-#             return x + h
+#             result = x + h
 #         else:
-#             return (x + h) / np.sqrt(2.0)
+#             result = (x + h) / np.sqrt(2.0)
+#
+#         result = self.forward_2(x)
+#
+#         return result
+#
+#     def forward_2(self, x, encoder_out=None):
+#         b, c, *spatial = x.shape
+#         hid_states = self.norm(x).view(b, c, -1)
+#
+#         qkv = self.qkv(hid_states)
+#         bs, width, length = qkv.shape
+#         assert width % (3 * self.n_heads) == 0
+#         ch = width // (3 * self.n_heads)
+#         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+#
+#         if encoder_out is not None:
+#             encoder_kv = self.encoder_kv(encoder_out)
+#             assert encoder_kv.shape[1] == self.n_heads * ch * 2
+#             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
+#             k = torch.cat([ek, k], dim=-1)
+#             v = torch.cat([ev, v], dim=-1)
+#
+#         scale = 1 / math.sqrt(math.sqrt(ch))
+#         weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+#         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+#
+#         a = torch.einsum("bts,bcs->bct", weight, v)
+#         h = a.reshape(bs, -1, length)
+#
+#         h = self.proj_out(h)
+#         h = h.reshape(b, c, *spatial)
+#
+#         return (x + h) / np.sqrt(2.0)
+
+
+# TODO(Patrick) - this can and should be removed
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
+class NIN(nn.Module):
+    def __init__(self, in_dim, num_units, init_scale=0.1):
+        super().__init__()
+        self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
+        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
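
The forward pass now ends by dividing the residual sum by rescale_output_factor. A small standalone check that the default factor of 1.0 reproduces the old `return x + h.reshape(b, c, *spatial)` and that sqrt(2.0) reproduces the `(x + h) / np.sqrt(2.0)` branch of the commented-out AttnBlockpp (arbitrary shapes, not part of the commit):

import math
import torch

b, c, spatial = 2, 8, (4, 4)
x = torch.randn(b, c, *spatial)
h = torch.randn(b, c, spatial[0] * spatial[1])

old_out = x + h.reshape(b, c, *spatial)        # previous return value

h_new = h.reshape(b, c, *spatial)
result = (x + h_new) / 1.0                     # rescale_output_factor=1.0 (default)
assert torch.equal(result, old_out)

rescaled = (x + h_new) / math.sqrt(2.0)        # rescale_output_factor=math.sqrt(2.0)
assert torch.allclose(rescaled, old_out / math.sqrt(2.0))
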
src/diffusers/models/embeddings.py
...
@@ -65,19 +65,3 @@ class GaussianFourierProjection(nn.Module):
     def forward(self, x):
         x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
         return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
-
-
-# unet_rl.py - TODO(need test)
-class SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
src/diffusers/models/unet_grad_tts.py
...
@@ -5,6 +5,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
+from .attention2d import LinearAttention
 
 
 class Mish(torch.nn.Module):
...
@@ -54,7 +55,7 @@ class ResnetBlock(torch.nn.Module):
         return output
 
 
-class LinearAttention(torch.nn.Module):
+class old_LinearAttention(torch.nn.Module):
     def __init__(self, dim, heads=4, dim_head=32):
         super(LinearAttention, self).__init__()
         self.heads = heads
...
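
With this import, the GradTTS UNet can use the shared LinearAttention from attention2d.py; the local copy is only kept as old_LinearAttention. A hypothetical usage sketch, assuming the shared class keeps the same dim/heads/dim_head signature and NCHW interface as the copy it replaces:

import torch
from diffusers.models.attention2d import LinearAttention

attn = LinearAttention(dim=64, heads=4, dim_head=32)
x = torch.randn(1, 64, 16, 16)   # (batch, channels, height, width)
out = attn(x)
print(out.shape)                 # expected: torch.Size([1, 64, 16, 16])
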
src/diffusers/models/unet_sde_score_estimation.py
...
@@ -22,10 +22,12 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import math
 
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
+from .attention2d import AttentionBlock
 
 
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
...
@@ -414,37 +416,6 @@ class Combine(nn.Module):
             raise ValueError(f"Method {self.method} not recognized.")
 
 
-class AttnBlockpp(nn.Module):
-    """Channel-wise self-attention block. Modified from DDPM."""
-
-    def __init__(self, channels, skip_rescale=False, init_scale=0.0):
-        super().__init__()
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
-        self.NIN_0 = NIN(channels, channels)
-        self.NIN_1 = NIN(channels, channels)
-        self.NIN_2 = NIN(channels, channels)
-        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
-        self.skip_rescale = skip_rescale
-
-    def forward(self, x):
-        B, C, H, W = x.shape
-        h = self.GroupNorm_0(x)
-        q = self.NIN_0(h)
-        k = self.NIN_1(h)
-        v = self.NIN_2(h)
-        w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
-        w = torch.reshape(w, (B, H, W, H * W))
-        w = F.softmax(w, dim=-1)
-        w = torch.reshape(w, (B, H, W, H, W))
-        h = torch.einsum("bhwij,bcij->bchw", w, v)
-        h = self.NIN_3(h)
-
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
-
-
 class Upsample(nn.Module):
     def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
...
@@ -756,7 +727,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
             modules[-1].weight.data = default_init()(modules[-1].weight.shape)
             nn.init.zeros_(modules[-1].bias)
 
-        AttnBlock = functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale)
+        AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
         Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
...
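
The NCSNpp model now builds its attention layers from the shared AttentionBlock, with overwrite_linear=True so the NIN weights can be ported and rescale_output_factor=math.sqrt(2.0) standing in for the old skip_rescale division by sqrt(2). A hypothetical sketch of what the new partial expands to; the channels keyword and the shapes are assumptions, since the call sites are outside this diff:

import functools
import math

import torch
from diffusers.models.attention2d import AttentionBlock

AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

attn = AttnBlock(channels=128)   # NIN weights get ported into the fused qkv projection
x = torch.randn(1, 128, 8, 8)
out = attn(x)                    # residual attention, output divided by sqrt(2)
print(out.shape)                 # expected: torch.Size([1, 128, 8, 8])
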