renzhc / diffusers_dcu · Commits · 23904d54

Commit 23904d54, authored Jul 01, 2022 by patil-suraj

Merge branch 'main' of https://github.com/huggingface/diffusers into conversion-scripts

Parents: 32b93da8, c691bb2f
Changes: showing 9 changed files with 476 additions and 1334 deletions (+476 / -1334):
setup.py  +1 -1
src/diffusers/dependency_versions_table.py  +1 -0
src/diffusers/models/resnet.py  +153 -276
src/diffusers/models/unet.py  +0 -42
src/diffusers/models/unet_glide.py  +53 -60
src/diffusers/models/unet_ldm.py  +141 -507
src/diffusers/models/unet_rl.py  +3 -21
src/diffusers/models/unet_sde_score_estimation.py  +104 -405
tests/test_modeling_utils.py  +20 -22
setup.py
@@ -88,7 +88,7 @@ _deps = [
     "requests",
     "torch>=1.4",
     "tensorboard",
-    "modelcards=0.1.4"
+    "modelcards==0.1.4"
 ]
 # this is a lookup table with items like:
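Note on the change above: the old pin "modelcards=0.1.4" uses a single "=", which is not a valid version operator, so the requirement string cannot be parsed; "==" is the correct exact-pin operator. A quick independent check with the third-party packaging library (not the repo's own setup tooling):

# Independent check that the old pin string is malformed while the new one parses.
from packaging.requirements import InvalidRequirement, Requirement

print(Requirement("modelcards==0.1.4"))   # parses: name "modelcards", specifier "==0.1.4"
try:
    Requirement("modelcards=0.1.4")       # single "=" is not a valid PEP 440 operator
except InvalidRequirement as err:
    print("rejected:", err)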
src/diffusers/dependency_versions_table.py
@@ -14,4 +14,5 @@ deps = {
     "requests": "requests",
     "torch": "torch>=1.4",
     "tensorboard": "tensorboard",
+    "modelcards": "modelcards==0.1.4",
 }
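The table above mirrors the _deps list in setup.py: each entry maps a bare package name to its full pip requirement string (see the "# this is a lookup table with items like:" comment in setup.py). A minimal sketch of how such a table can be consumed to verify installed versions; the check_pin helper is hypothetical, not the library's actual API:

# Hypothetical consumer of a deps-style lookup table (not diffusers' own helper).
from importlib.metadata import version
from packaging.requirements import Requirement

deps = {
    "torch": "torch>=1.4",
    "modelcards": "modelcards==0.1.4",
}

def check_pin(name: str) -> None:
    """Raise if the installed package does not satisfy the pinned requirement."""
    req = Requirement(deps[name])
    installed = version(req.name)  # raises PackageNotFoundError if not installed
    if not req.specifier.contains(installed, prereleases=True):
        raise ImportError(f"{req.name}{req.specifier} is required, found {installed}")

check_pin("torch")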
src/diffusers/models/resnet.py
This diff is collapsed.
src/diffusers/models/unet.py
@@ -34,48 +34,6 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

-
-# class ResnetBlock(nn.Module):
-#     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
-#         super().__init__()
-#         self.in_channels = in_channels
-#         out_channels = in_channels if out_channels is None else out_channels
-#         self.out_channels = out_channels
-#         self.use_conv_shortcut = conv_shortcut
-#
-#         self.norm1 = Normalize(in_channels)
-#         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-#         self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-#         self.norm2 = Normalize(out_channels)
-#         self.dropout = torch.nn.Dropout(dropout)
-#         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
-#         if self.in_channels != self.out_channels:
-#             if self.use_conv_shortcut:
-#                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-#             else:
-#                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-#
-#     def forward(self, x, temb):
-#         h = x
-#         h = self.norm1(h)
-#         h = nonlinearity(h)
-#         h = self.conv1(h)
-#
-#         h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
-#
-#         h = self.norm2(h)
-#         h = nonlinearity(h)
-#         h = self.dropout(h)
-#         h = self.conv2(h)
-#
-#         if self.in_channels != self.out_channels:
-#             if self.use_conv_shortcut:
-#                 x = self.conv_shortcut(x)
-#             else:
-#                 x = self.nin_shortcut(x)
-#
-#         return x + h
-

 class UNetModel(ModelMixin, ConfigMixin):
     def __init__(
         self,
src/diffusers/models/unet_glide.py
@@ -6,7 +6,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample


 def convert_module_to_f16(l):
@@ -29,19 +29,6 @@ def convert_module_to_f32(l):
         l.bias.data = l.bias.data.float()


-def avg_pool_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D average pooling module.
-    """
-    if dims == 1:
-        return nn.AvgPool1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.AvgPool2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.AvgPool3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
@@ -101,7 +88,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     def forward(self, x, emb, encoder_out=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                 x = layer(x, emb)
             elif isinstance(layer, AttentionBlock):
                 x = layer(x, encoder_out)
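To make the dispatch above concrete: the container forwards the time embedding only to layers that accept it (now including ResnetBlock), while attention layers receive the optional encoder output. A small self-contained sketch of the same routing pattern, with stand-in layer classes that are not part of the library:

# Simplified sketch of the TimestepEmbedSequential routing idea (stand-in classes, not library code).
import torch
import torch.nn as nn

class TimeAwareLayer(nn.Module):
    """Toy layer that consumes the time embedding alongside the activations."""
    def __init__(self, dim, temb_dim):
        super().__init__()
        self.lin = nn.Linear(dim + temb_dim, dim)

    def forward(self, x, temb):
        return self.lin(torch.cat([x, temb], dim=-1))

class MixedSequential(nn.Sequential):
    """Pass (x, temb) to time-aware layers and plain x to everything else."""
    def forward(self, x, temb):
        for layer in self:
            if isinstance(layer, TimeAwareLayer):
                x = layer(x, temb)
            else:
                x = layer(x)
        return x

block = MixedSequential(TimeAwareLayer(8, 4), nn.ReLU(), nn.Linear(8, 8))
out = block(torch.randn(2, 8), torch.randn(2, 4))  # -> shape (2, 8)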
@@ -190,14 +177,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(mult * model_channels),
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                    ResnetBlock(
+                        in_channels=ch,
+                        out_channels=mult * model_channels,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                        overwrite_for_glide=True,
                     )
                 ]
                 ch = int(mult * model_channels)
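The same call-site migration repeats in the hunks below: the positional ResBlock arguments become keyword arguments on ResnetBlock, with the time-embedding size moving to temb_channels and the scale/shift behaviour expressed through time_embedding_norm. A compact sketch of that argument mapping as it appears in the diff; the helper itself is only illustrative, not library code:

# Illustrative mapping from the old positional ResBlock arguments to the new
# ResnetBlock keyword arguments used throughout this file.
def resnet_kwargs(ch, time_embed_dim, dropout, out_channels, use_scale_shift_norm):
    return dict(
        in_channels=ch,
        out_channels=out_channels,
        dropout=dropout,
        temb_channels=time_embed_dim,
        eps=1e-5,
        non_linearity="silu",
        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
        overwrite_for_glide=True,
    )

# e.g. ResnetBlock(**resnet_kwargs(ch, time_embed_dim, dropout, mult * model_channels, use_scale_shift_norm))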
@@ -218,14 +206,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
+                        ResnetBlock(
+                            in_channels=ch,
+                            out_channels=out_ch,
+                            dropout=dropout,
+                            temb_channels=time_embed_dim,
+                            eps=1e-5,
+                            non_linearity="silu",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                            overwrite_for_glide=True,
+                            down=True,
                         )
                         if resblock_updown
@@ -240,13 +229,14 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         self._feature_size += ch

         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                overwrite_for_glide=True,
             ),
             AttentionBlock(
                 ch,
@@ -255,13 +245,14 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 num_head_channels=num_head_channels,
                 encoder_channels=transformer_dim,
             ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                overwrite_for_glide=True,
             ),
         )
         self._feature_size += ch
@@ -271,15 +262,16 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(model_channels * mult),
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                    ResnetBlock(
+                        in_channels=ch + ich,
+                        out_channels=model_channels * mult,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                        overwrite_for_glide=True,
+                    ),
                 ]
                 ch = int(model_channels * mult)
                 if ds in attention_resolutions:
@@ -295,14 +287,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 if level and i == num_res_blocks:
                     out_ch = ch
                     layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            up=True,
+                        ResnetBlock(
+                            in_channels=ch,
+                            out_channels=out_ch,
+                            dropout=dropout,
+                            temb_channels=time_embed_dim,
+                            eps=1e-5,
+                            non_linearity="silu",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
+                            overwrite_for_glide=True,
+                            up=True,
                         )
                         if resblock_updown
src/diffusers/models/unet_ldm.py
This diff is collapsed.
src/diffusers/models/unet_rl.py
@@ -6,7 +6,7 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
-from .resnet import ResidualTemporalBlock
+from .resnet import Downsample, ResidualTemporalBlock, Upsample


 class SinusoidalPosEmb(nn.Module):
@@ -18,24 +18,6 @@ class SinusoidalPosEmb(nn.Module):
         return get_timestep_embedding(x, self.dim)


-class Downsample1d(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
-class Upsample1d(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 class RearrangeDim(nn.Module):
     def __init__(self):
         super().__init__()
@@ -114,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
                 [
                     ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                     ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
-                    Downsample1d(dim_out) if not is_last else nn.Identity(),
+                    Downsample(dim_out, use_conv=True, dims=1) if not is_last else nn.Identity(),
                 ]
             )
         )
@@ -134,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
                 [
                     ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                     ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
-                    Upsample1d(dim_in) if not is_last else nn.Identity(),
+                    Upsample(dim_in, use_conv_transpose=True, dims=1) if not is_last else nn.Identity(),
                 ]
             )
         )
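For reference, the two helpers removed above were thin wrappers around strided convolutions, and the shared Downsample/Upsample blocks are now used with dims=1 instead. A quick shape check of what the removed wrappers did, written in plain PyTorch and independent of the library:

# Shape check for the removed 1-D wrappers: a stride-2 Conv1d halves the horizon,
# a stride-2 ConvTranspose1d doubles it. (Plain PyTorch, not diffusers code.)
import torch
import torch.nn as nn

dim, horizon = 8, 16
x = torch.randn(2, dim, horizon)             # (batch, channels, horizon)

down = nn.Conv1d(dim, dim, 3, 2, 1)          # what Downsample1d wrapped
up = nn.ConvTranspose1d(dim, dim, 4, 2, 1)   # what Upsample1d wrapped

assert down(x).shape == (2, dim, horizon // 2)
assert up(x).shape == (2, dim, horizon * 2)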
src/diffusers/models/unet_sde_score_estimation.py
This diff is collapsed.
tests/test_modeling_utils.py
@@ -259,7 +259,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         # fmt: off
         expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


 class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
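The tolerance changes in this file swap an absolute bound for a relative one: torch.allclose checks |input - other| <= atol + rtol * |other|, so rtol scales with the magnitude of the expected values. A small illustration of the difference (values made up, not taken from the tests):

# atol vs rtol in torch.allclose (example values are made up).
import torch

expected = torch.tensor([100.0, 0.001])
actual = expected * 1.005                                       # 0.5% relative error everywhere

print(torch.allclose(actual, expected, rtol=0.0, atol=1e-3))    # False: the large element is off by 0.5
print(torch.allclose(actual, expected, rtol=1e-2, atol=0.0))    # True: 0.5% error fits a 1% relative bound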
@@ -607,7 +607,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
         expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))


 class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
@@ -678,7 +678,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
         expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))


 class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
@@ -742,18 +742,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         num_channels = 3
         sizes = (32, 32)

-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor(batch_size * [10]).to(torch_device)
+        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
+        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

         with torch.no_grad():
             output = model(noise, time_step)

         output_slice = output[0, -3:, -3:, -1].flatten().cpu()
         # fmt: off
-        expected_output_slice = torch.tensor([3.1909e-07, -8.5393e-08, 4.8460e-07, -4.5550e-07, -1.3205e-06, -6.3475e-07, 9.7837e-07, 2.9974e-07, 1.2345e-06])
+        expected_output_slice = torch.tensor([0.1315, 0.0741, 0.0393, 0.0455, 0.0556, 0.0180, -0.0832, -0.0644, -0.0856])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

     def test_output_pretrained_ve_large(self):
         model = NCSNpp.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy")
@@ -768,21 +768,21 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         num_channels = 3
         sizes = (32, 32)

-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor(batch_size * [10]).to(torch_device)
+        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
+        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

         with torch.no_grad():
             output = model(noise, time_step)

         output_slice = output[0, -3:, -3:, -1].flatten().cpu()
         # fmt: off
-        expected_output_slice = torch.tensor([-8.3299e-07, -9.0431e-07, 4.0585e-08, 9.7563e-07, 1.0280e-06, 1.0133e-06, 1.4979e-06, -2.9716e-07, -6.1817e-07])
+        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

     def test_output_pretrained_vp(self):
-        model = NCSNpp.from_pretrained("fusing/ddpm-cifar10-vp-dummy")
+        model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
         model.eval()
         model.to(torch_device)
@@ -794,18 +794,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         num_channels = 3
         sizes = (32, 32)

-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor(batch_size * [10]).to(torch_device)
+        noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
+        time_step = torch.tensor(batch_size * [9.0]).to(torch_device)

         with torch.no_grad():
             output = model(noise, time_step)

         output_slice = output[0, -3:, -3:, -1].flatten().cpu()
         # fmt: off
-        expected_output_slice = torch.tensor([-3.9086e-07, -1.1001e-05, 1.8881e-06, 1.1106e-05, 1.6629e-06, 2.9820e-06, 8.4978e-06, 8.0253e-07, 1.5435e-06])
+        expected_output_slice = torch.tensor([0.3303, -0.2275, -2.8872, -0.1309, -1.2861, 3.4567, -1.0083, 2.5325, -1.3866])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


 class VQModelTests(ModelTesterMixin, unittest.TestCase):
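The pattern repeated in these hunks is a small output-regression test: build a deterministic input, run a single forward pass under torch.no_grad(), and compare a tiny slice of the output against hard-coded reference values. A stripped-down, self-contained version of that pattern; the model here is a stand-in nn.Linear rather than an NCSNpp or VQModel:

# Minimal sketch of the slice-regression test pattern (stand-in model; the
# "expected" values in the real suite are hard-coded numbers, not computed like this).
import torch
import torch.nn as nn

torch.manual_seed(0)                 # deterministic weights for the stand-in model
model = nn.Linear(4, 4)
model.eval()

noise = torch.ones((1, 4))           # fixed input, analogous to torch.ones(...) above
with torch.no_grad():
    output = model(noise)

output_slice = output[0, -3:].flatten().cpu()
expected_output_slice = output_slice.clone()
assert torch.allclose(output_slice, expected_output_slice, rtol=1e-2)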
@@ -878,10 +878,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
         expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462,
                                               -0.4218])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


 class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
@@ -950,10 +949,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
         expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662,
                                               0.1750])
         # fmt: on
-        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


 class PipelineTesterMixin(unittest.TestCase):