renzhc / diffusers_dcu / Commits

Commit ee010726 authored Jun 27, 2022 by patil-suraj

cleanup

parent abcb2597
Showing 2 changed files with 5 additions and 86 deletions (+5 -86)

src/diffusers/models/resnet.py    +0 -82
src/diffusers/models/unet_ldm.py  +5 -4
src/diffusers/models/resnet.py

@@ -125,88 +125,6 @@ class Downsample(nn.Module):
        return self.down(x)
class UNetUpsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x
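As a quick orientation for the removed classes: this hunk drops several variants of 2x upsampling with an optional 3x3 convolution afterwards. A minimal sketch of what UNetUpsample's forward pass does (the tensor sizes below are assumptions for illustration, not from the repository):

import torch
import torch.nn.functional as F

# Hypothetical input: batch of 1, 8 channels, 16x16 feature map.
x = torch.randn(1, 8, 16, 16)

# Nearest-neighbour upsampling as in UNetUpsample.forward: each pixel is
# repeated into a 2x2 block, so the spatial size goes from 16x16 to 32x32.
y = F.interpolate(x, scale_factor=2.0, mode="nearest")
print(y.shape)  # torch.Size([1, 8, 32, 32])

# With with_conv=True, a 3x3 stride-1 padding-1 convolution follows; it keeps
# the 32x32 size and only mixes information locally across channels.
conv = torch.nn.Conv2d(8, 8, kernel_size=3, stride=1, padding=1)
print(conv(y).shape)  # torch.Size([1, 8, 32, 32])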
class GlideUpsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
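The dims == 3 branch implements exactly what the docstring says: for a 5D tensor only the inner two dimensions are doubled, while the first spatial (e.g. temporal) axis is left untouched. A minimal sketch of that behaviour; the shapes are illustrative assumptions:

import torch
import torch.nn.functional as F

# Hypothetical 3D signal: batch 1, 4 channels, depth 5, height 8, width 8.
x = torch.randn(1, 4, 5, 8, 8)

# Only the inner two dimensions are doubled; the depth stays at 5.
y = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
print(y.shape)  # torch.Size([1, 4, 5, 16, 16])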
class LDMUpsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
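GlideUpsample and LDMUpsample both call a conv_nd helper that is not part of this diff. In the guided-diffusion-style code these classes were copied from, that helper usually just dispatches on dims; the sketch below is an assumption about its behaviour, not the repository's definition:

import torch.nn as nn

def conv_nd(dims, *args, **kwargs):
    # Assumed behaviour: return the convolution module matching `dims`.
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")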
class GradTTSUpsample(torch.nn.Module):
    def __init__(self, dim):
        super(GradTTSUpsample, self).__init__()
        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)
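Unlike the interpolation-based classes above, GradTTSUpsample learns its upsampling with a transposed convolution. With kernel 4, stride 2 and padding 1 the spatial size still doubles: out = (in - 1) * 2 - 2 * 1 + 4 = 2 * in. A quick check of that arithmetic (the sizes are assumptions):

import torch

up = torch.nn.ConvTranspose2d(8, 8, 4, 2, 1)
x = torch.randn(1, 8, 16, 16)
# (16 - 1) * 2 - 2 * 1 + 4 = 32, i.e. the same doubling as interpolate(scale_factor=2).
print(up(x).shape)  # torch.Size([1, 8, 32, 32])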
# TODO (patil-suraj): needs test
class Upsample1d(nn.Module):
    def __init__(self, dim):
...

src/diffusers/models/unet_ldm.py

@@ -82,7 +82,7 @@ def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
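For context, the Normalize helper shown above wraps GroupNorm with 32 groups. A quick illustration of its shapes (the channel count is an assumption):

import torch

norm = torch.nn.GroupNorm(num_groups=32, num_channels=64, eps=1e-6, affine=True)
x = torch.randn(2, 64, 8, 8)
# Each group of 64 / 32 = 2 channels is normalized to roughly zero mean and unit variance.
print(norm(x).shape)  # torch.Size([2, 64, 8, 8])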
-#class LinearAttention(nn.Module):
+# class LinearAttention(nn.Module):
# def __init__(self, dim, heads=4, dim_head=32):
# super().__init__()
# self.heads = heads
...

@@ -102,7 +102,7 @@ def Normalize(in_channels):
# return self.to_out(out)
#
-#class SpatialSelfAttention(nn.Module):
+# class SpatialSelfAttention(nn.Module):
# def __init__(self, in_channels):
# super().__init__()
# self.in_channels = in_channels
@@ -120,7 +120,7 @@ def Normalize(in_channels):
...
@@ -120,7 +120,7 @@ def Normalize(in_channels):
# k = self.k(h_)
# k = self.k(h_)
# v = self.v(h_)
# v = self.v(h_)
#
#
# compute attention
# compute attention
# b, c, h, w = q.shape
# b, c, h, w = q.shape
# q = rearrange(q, "b c h w -> b (h w) c")
# q = rearrange(q, "b c h w -> b (h w) c")
# k = rearrange(k, "b c h w -> b c (h w)")
# k = rearrange(k, "b c h w -> b c (h w)")
...

@@ -129,7 +129,7 @@ def Normalize(in_channels):
# w_ = w_ * (int(c) ** (-0.5))
# w_ = torch.nn.functional.softmax(w_, dim=2)
#
# attend to values
# v = rearrange(v, "b c h w -> b c (h w)")
# w_ = rearrange(w_, "b i j -> b j i")
# h_ = torch.einsum("bij,bjk->bik", v, w_)
...

@@ -139,6 +139,7 @@ def Normalize(in_channels):
# return x + h_
#
class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
...
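This last hunk shows only the CrossAttention signature; the body is elided in the view above. For orientation, a generic multi-head cross-attention forward pass with those constructor arguments typically looks like the sketch below. It is an illustrative stand-in, not the implementation in unet_ldm.py, and all helper names are assumptions:

import torch
import torch.nn as nn

class MinimalCrossAttention(nn.Module):
    # Illustrative stand-in, not the CrossAttention class from unet_ldm.py.
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        context_dim = context_dim if context_dim is not None else query_dim
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None):
        # x: (batch, seq_q, query_dim); context defaults to x (plain self-attention).
        context = context if context is not None else x
        b, n, _ = x.shape
        h, d = self.heads, self.to_q.out_features // self.heads

        def split(t):
            # (batch, seq, heads * dim_head) -> (batch, heads, seq, dim_head)
            return t.view(b, t.shape[1], h, d).transpose(1, 2)

        q, k, v = split(self.to_q(x)), split(self.to_k(context)), split(self.to_v(context))
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, h * d)
        return self.to_out(out)

# Usage sketch: a batch of 2 sequences of 16 query tokens attending to 77 context tokens.
attn = MinimalCrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
print(attn(torch.randn(2, 16, 320), torch.randn(2, 77, 768)).shape)  # torch.Size([2, 16, 320])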