Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
02a76c2c
Commit
02a76c2c
authored
Jun 27, 2022
by
Patrick von Platen
Browse files
consolidate timestep embeds
parent
014ebc59
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
173 additions
and
853 deletions
+173
-853
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+84
-65
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+22
-21
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+23
-22
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+22
-30
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+12
-11
tests/test_layers_utils.py
tests/test_layers_utils.py
+10
-704
No files found.
src/diffusers/models/embeddings.py
View file @
02a76c2c
...
...
@@ -11,49 +11,104 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
math
import
numpy
as
np
from
torch
import
nn
import
torch.nn.functional
as
F
# unet.py
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
def
get_timestep_embedding
(
timesteps
,
embedding_dim
,
flip_sin_to_cos
=
False
,
downscale_freq_shift
=
1
,
max_period
=
10000
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
/
(
embedding_dim
//
2
-
downscale_freq_shift
)
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
timesteps
[:,
None
].
float
()
*
emb
[
None
,
:]
# concat sine and cosine embeddings
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
# flip sine and cosine embeddings
if
flip_sin_to_cos
:
emb
=
torch
.
cat
([
emb
[:,
half_dim
:],
emb
[:,
:
half_dim
]],
dim
=-
1
)
# zero pad
if
embedding_dim
%
2
==
1
:
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
# unet_glide.py
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half
=
dim
//
2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
device
=
timesteps
.
device
)
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
if
dim
%
2
:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
return
embedding
#def get_timestep_embedding(timesteps, embedding_dim):
# """
# This matches the implementation in Denoising Diffusion Probabilistic Models:
# From Fairseq.
# Build sinusoidal embeddings.
# This matches the implementation in tensor2tensor, but differs slightly
# from the description in Section 3.5 of "Attention Is All You Need".
# """
# assert len(timesteps.shape) == 1
#
# half_dim = embedding_dim // 2
# emb = math.log(10000) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
# emb = emb.to(device=timesteps.device)
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
#def timestep_embedding(timesteps, dim, max_period=10000):
# """
# Create sinusoidal timestep embeddings.
#
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
# """
# half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
# device=timesteps.device
# )
# args = timesteps[:, None].float() * freqs[None, :]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# return embedding
#def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
# half_dim = embedding_dim // 2
# magic number 10000 is from transformers
# emb = math.log(max_positions) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = F.pad(emb, (0, 1), mode="constant")
# assert emb.shape == (timesteps.shape[0], embedding_dim)
# return emb
# unet_grad_tts.py
class
SinusoidalPosEmb
(
torch
.
nn
.
Module
):
...
...
@@ -70,26 +125,6 @@ class SinusoidalPosEmb(torch.nn.Module):
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
return
emb
# unet_ldm.py
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half
=
dim
//
2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
device
=
timesteps
.
device
)
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
if
dim
%
2
:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
return
embedding
# unet_rl.py
class
SinusoidalPosEmb
(
nn
.
Module
):
...
...
@@ -106,22 +141,6 @@ class SinusoidalPosEmb(nn.Module):
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
return
emb
# unet_sde_score_estimation.py
def
get_timestep_embedding
(
timesteps
,
embedding_dim
,
max_positions
=
10000
):
assert
len
(
timesteps
.
shape
)
==
1
# and timesteps.dtype == tf.int32
half_dim
=
embedding_dim
//
2
# magic number 10000 is from transformers
emb
=
math
.
log
(
max_positions
)
/
(
half_dim
-
1
)
# emb = math.log(2.) / (half_dim - 1)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
,
device
=
timesteps
.
device
)
*
-
emb
)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
F
.
pad
(
emb
,
(
0
,
1
),
mode
=
"constant"
)
assert
emb
.
shape
==
(
timesteps
.
shape
[
0
],
embedding_dim
)
return
emb
# unet_sde_score_estimation.py
class
GaussianFourierProjection
(
nn
.
Module
):
...
...
src/diffusers/models/unet.py
View file @
02a76c2c
...
...
@@ -30,27 +30,28 @@ from tqdm import tqdm
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
from
.embeddings
import
get_timestep_embedding
#def get_timestep_embedding(timesteps, embedding_dim):
# """
# This matches the implementation in Denoising Diffusion Probabilistic Models:
# From Fairseq.
# Build sinusoidal embeddings.
# This matches the implementation in tensor2tensor, but differs slightly
# from the description in Section 3.5 of "Attention Is All You Need".
# """
# assert len(timesteps.shape) == 1
#
# half_dim = embedding_dim // 2
# emb = math.log(10000) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
# emb = emb.to(device=timesteps.device)
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
# return emb
def
nonlinearity
(
x
):
...
...
src/diffusers/models/unet_glide.py
View file @
02a76c2c
...
...
@@ -7,6 +7,7 @@ import torch.nn.functional as F
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
def
convert_module_to_f16
(
l
):
...
...
@@ -86,25 +87,25 @@ def normalization(channels, swish=0.0):
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half
=
dim
//
2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
device
=
timesteps
.
device
)
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
if
dim
%
2
:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
return
embedding
#
def timestep_embedding(timesteps, dim, max_period=10000):
#
"""
#
Create sinusoidal timestep embeddings.
#
#
:param timesteps: a 1-D Tensor of N indices, one per batch element.
#
These may be fractional.
#
:param dim: the dimension of the output.
#
:param max_period: controls the minimum frequency of the embeddings.
#
:return: an [N x dim] Tensor of positional embeddings.
#
"""
#
half = dim // 2
#
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
#
device=timesteps.device
#
)
#
args = timesteps[:, None].float() * freqs[None]
#
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
#
if dim % 2:
#
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
#
return embedding
def
zero_module
(
module
):
...
...
@@ -627,7 +628,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
"""
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
))
h
=
x
.
type
(
self
.
dtype
)
for
module
in
self
.
input_blocks
:
...
...
@@ -714,7 +715,7 @@ class GlideTextToImageUNetModel(GlideUNetModel):
def
forward
(
self
,
x
,
timesteps
,
transformer_out
=
None
):
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
))
# project the last token
transformer_proj
=
self
.
transformer_proj
(
transformer_out
[:,
-
1
])
...
...
@@ -806,7 +807,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
x
=
torch
.
cat
([
x
,
upsampled
],
dim
=
1
)
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
))
h
=
x
for
module
in
self
.
input_blocks
:
...
...
src/diffusers/models/unet_ldm.py
View file @
02a76c2c
...
...
@@ -16,6 +16,7 @@ except:
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
def
exists
(
val
):
...
...
@@ -316,34 +317,25 @@ def normalization(channels, swish=0.0):
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half
=
dim
//
2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
device
=
timesteps
.
device
)
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
if
dim
%
2
:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
return
embedding
def
zero_module
(
module
):
"""
Zero out the parameters of a module and return it.
"""
for
p
in
module
.
parameters
():
p
.
detach
().
zero_
()
return
module
#def timestep_embedding(timesteps, dim, max_period=10000):
# """
# Create sinusoidal timestep embeddings.
#
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
# """
# half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
# device=timesteps.device
# )
# args = timesteps[:, None].float() * freqs[None]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# return embedding
## go
...
...
@@ -1026,7 +1018,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
hs
=
[]
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
x
.
device
)
t_emb
=
timestep_embedding
(
timesteps
,
self
.
model_channels
)
t_emb
=
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
)
emb
=
self
.
time_embed
(
t_emb
)
if
self
.
num_classes
is
not
None
:
...
...
@@ -1240,7 +1232,7 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
))
results
=
[]
h
=
x
.
type
(
self
.
dtype
)
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
02a76c2c
...
...
@@ -26,6 +26,7 @@ import torch.nn.functional as F
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
def
upfirdn2d
(
input
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
...
...
@@ -381,21 +382,21 @@ def get_act(nonlinearity):
raise
NotImplementedError
(
"activation function does not exist!"
)
def
get_timestep_embedding
(
timesteps
,
embedding_dim
,
max_positions
=
10000
):
assert
len
(
timesteps
.
shape
)
==
1
# and timesteps.dtype == tf.int32
half_dim
=
embedding_dim
//
2
#
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
#
assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
#
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
emb
=
math
.
log
(
max_positions
)
/
(
half_dim
-
1
)
#
emb = math.log(max_positions) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
,
device
=
timesteps
.
device
)
*
-
emb
)
#
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
F
.
pad
(
emb
,
(
0
,
1
),
mode
=
"constant"
)
assert
emb
.
shape
==
(
timesteps
.
shape
[
0
],
embedding_dim
)
return
emb
#
emb = timesteps.float()[:, None] * emb[None, :]
#
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
#
if embedding_dim % 2 == 1: # zero pad
#
emb = F.pad(emb, (0, 1), mode="constant")
#
assert emb.shape == (timesteps.shape[0], embedding_dim)
#
return emb
def
default_init
(
scale
=
1.0
):
...
...
tests/test_layers_utils.py
View file @
02a76c2c
...
...
@@ -21,718 +21,24 @@ import unittest
import
numpy
as
np
import
torch
from
diffusers
import
(
BDDMPipeline
,
DDIMPipeline
,
DDIMScheduler
,
DDPMPipeline
,
DDPMScheduler
,
GlidePipeline
,
GlideSuperResUNetModel
,
GlideTextToImageUNetModel
,
GradTTSPipeline
,
GradTTSScheduler
,
LatentDiffusionPipeline
,
PNDMPipeline
,
PNDMScheduler
,
UNetGradTTSModel
,
UNetLDMModel
,
UNetModel
,
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipelines.pipeline_bddm
import
DiffWave
#from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding, a_get_timestep_embedding
from
diffusers.models.embeddings
import
get_timestep_embedding
,
timestep_embedding
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
class
ConfigTester
(
unittest
.
TestCase
):
def
test_load_not_from_mixin
(
self
):
with
self
.
assertRaises
(
ValueError
):
ConfigMixin
.
from_config
(
"dummy_path"
)
class
EmbeddingsTests
(
unittest
.
TestCase
):
def
test_
save_load
(
self
):
class
SampleObject
(
ConfigMixin
):
config_name
=
"config.json"
def
test_
timestep_embeddings
(
self
):
embedding_dim
=
16
timesteps
=
torch
.
arange
(
10
)
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
e
=
[
1
,
3
],
):
self
.
register_to_config
(
a
=
a
,
b
=
b
,
c
=
c
,
d
=
d
,
e
=
e
)
t1
=
get_timestep_embedding
(
timesteps
,
embedding_dim
)
t2
=
timestep_embedding
(
timesteps
,
embedding_dim
)
t3
=
get_timestep_embedding
(
timesteps
,
embedding_dim
,
flip_sin_to_cos
=
True
,
downscale_freq_factor
=
8
)
obj
=
SampleObject
()
config
=
obj
.
config
import
ipdb
;
ipdb
.
set_trace
()
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
(
2
,
5
)
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
new_obj
=
SampleObject
.
from_config
(
tmpdirname
)
new_config
=
new_obj
.
config
# unfreeze configs
config
=
dict
(
config
)
new_config
=
dict
(
new_config
)
assert
config
.
pop
(
"c"
)
==
(
2
,
5
)
# instantiated as tuple
assert
new_config
.
pop
(
"c"
)
==
[
2
,
5
]
# saved & loaded as list because of json
assert
config
==
new_config
class
ModelTesterMixin
:
def
test_from_pretrained_save_pretrained
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_pretrained
(
tmpdirname
)
new_model
=
self
.
model_class
.
from_pretrained
(
tmpdirname
)
new_model
.
to
(
torch_device
)
with
torch
.
no_grad
():
image
=
model
(
**
inputs_dict
)
new_image
=
new_model
(
**
inputs_dict
)
max_diff
=
(
image
-
new_image
).
abs
().
sum
().
item
()
self
.
assertLessEqual
(
max_diff
,
1e-5
,
"Models give different forward passes"
)
def
test_determinism
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
torch
.
no_grad
():
first
=
model
(
**
inputs_dict
)
second
=
model
(
**
inputs_dict
)
out_1
=
first
.
cpu
().
numpy
()
out_2
=
second
.
cpu
().
numpy
()
out_1
=
out_1
[
~
np
.
isnan
(
out_1
)]
out_2
=
out_2
[
~
np
.
isnan
(
out_2
)]
max_diff
=
np
.
amax
(
np
.
abs
(
out_1
-
out_2
))
self
.
assertLessEqual
(
max_diff
,
1e-5
)
def
test_output
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
torch
.
no_grad
():
output
=
model
(
**
inputs_dict
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"x"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_forward_signature
(
self
):
init_dict
,
_
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
signature
=
inspect
.
signature
(
model
.
forward
)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names
=
[
*
signature
.
parameters
.
keys
()]
expected_arg_names
=
[
"x"
,
"timesteps"
]
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
def
test_model_from_config
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
# test if the model can be loaded from the config
# and has all the expected shape
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_config
(
tmpdirname
)
new_model
=
self
.
model_class
.
from_config
(
tmpdirname
)
new_model
.
to
(
torch_device
)
new_model
.
eval
()
# check if all paramters shape are the same
for
param_name
in
model
.
state_dict
().
keys
():
param_1
=
model
.
state_dict
()[
param_name
]
param_2
=
new_model
.
state_dict
()[
param_name
]
self
.
assertEqual
(
param_1
.
shape
,
param_2
.
shape
)
with
torch
.
no_grad
():
output_1
=
model
(
**
inputs_dict
)
output_2
=
new_model
(
**
inputs_dict
)
self
.
assertEqual
(
output_1
.
shape
,
output_2
.
shape
)
def
test_training
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
train
()
output
=
model
(
**
inputs_dict
)
noise
=
torch
.
randn
((
inputs_dict
[
"x"
].
shape
[
0
],)
+
self
.
get_output_shape
).
to
(
torch_device
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
.
backward
()
class
UnetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
get_input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"ch"
:
32
,
"ch_mult"
:
(
1
,
2
),
"num_res_blocks"
:
2
,
"attn_resolutions"
:
(
16
,),
"resolution"
:
32
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
GlideSuperResUNetTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
GlideSuperResUNetModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
6
sizes
=
(
32
,
32
)
low_res_size
=
(
4
,
4
)
noise
=
torch
.
randn
((
batch_size
,
num_channels
//
2
)
+
sizes
).
to
(
torch_device
)
low_res
=
torch
.
randn
((
batch_size
,
3
)
+
low_res_size
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"low_res"
:
low_res
}
@
property
def
get_input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
6
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"attention_resolutions"
:
(
2
,),
"channel_mult"
:
(
1
,
2
),
"in_channels"
:
6
,
"out_channels"
:
6
,
"model_channels"
:
32
,
"num_head_channels"
:
8
,
"num_heads_upsample"
:
1
,
"num_res_blocks"
:
2
,
"resblock_updown"
:
True
,
"resolution"
:
32
,
"use_scale_shift_norm"
:
True
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_output
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
torch
.
no_grad
():
output
=
model
(
**
inputs_dict
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"x"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
GlideSuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
GlideSuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
)
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
3
,
64
,
64
)
low_res
=
torch
.
randn
(
1
,
3
,
4
,
4
)
time_step
=
torch
.
tensor
([
42
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
,
low_res
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
22.8782
,
-
23.2652
,
-
15.3966
,
-
22.8034
,
-
23.3159
,
-
15.5640
,
-
15.3970
,
-
15.4614
,
-
10.4370
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
GlideTextToImageUNetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
GlideTextToImageUNetModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
transformer_dim
=
32
seq_len
=
16
noise
=
torch
.
randn
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
emb
=
torch
.
randn
((
batch_size
,
seq_len
,
transformer_dim
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"transformer_out"
:
emb
}
@
property
def
get_input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
6
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"attention_resolutions"
:
(
2
,),
"channel_mult"
:
(
1
,
2
),
"in_channels"
:
3
,
"out_channels"
:
6
,
"model_channels"
:
32
,
"num_head_channels"
:
8
,
"num_heads_upsample"
:
1
,
"num_res_blocks"
:
2
,
"resblock_updown"
:
True
,
"resolution"
:
32
,
"use_scale_shift_norm"
:
True
,
"transformer_dim"
:
32
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_output
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
torch
.
no_grad
():
output
=
model
(
**
inputs_dict
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"x"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
GlideTextToImageUNetModel
.
from_pretrained
(
"fusing/unet-glide-text2im-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
GlideTextToImageUNetModel
.
from_pretrained
(
"fusing/unet-glide-text2im-dummy"
)
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
((
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)).
to
(
torch_device
)
emb
=
torch
.
randn
((
1
,
16
,
model
.
config
.
transformer_dim
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
,
emb
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
2.7766
,
-
10.3558
,
-
14.9149
,
-
0.9376
,
-
14.9175
,
-
17.7679
,
-
5.5565
,
-
12.9521
,
-
12.9845
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
UNetLDMModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetLDMModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
4
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
get_input_shape
(
self
):
return
(
4
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
4
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"image_size"
:
32
,
"in_channels"
:
4
,
"out_channels"
:
4
,
"model_channels"
:
32
,
"num_res_blocks"
:
2
,
"attention_resolutions"
:
(
16
,),
"channel_mult"
:
(
1
,
2
),
"num_heads"
:
2
,
"conv_resample"
:
True
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetLDMModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetLDMModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
image_size
,
model
.
config
.
image_size
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
13.3258
,
-
20.1100
,
-
15.9873
,
-
17.6617
,
-
23.0596
,
-
17.9419
,
-
13.3675
,
-
16.1889
,
-
12.3800
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
UNetGradTTSModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetGradTTSModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_features
=
32
seq_len
=
16
noise
=
floats_tensor
((
batch_size
,
num_features
,
seq_len
)).
to
(
torch_device
)
condition
=
floats_tensor
((
batch_size
,
num_features
,
seq_len
)).
to
(
torch_device
)
mask
=
floats_tensor
((
batch_size
,
1
,
seq_len
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
batch_size
).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"mu"
:
condition
,
"mask"
:
mask
}
@
property
def
get_input_shape
(
self
):
return
(
4
,
32
,
16
)
@
property
def
get_output_shape
(
self
):
return
(
4
,
32
,
16
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"dim"
:
64
,
"groups"
:
4
,
"dim_mults"
:
(
1
,
2
),
"n_feats"
:
32
,
"pe_scale"
:
1000
,
"n_spks"
:
1
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetGradTTSModel
.
from_pretrained
(
"fusing/unet-grad-tts-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetGradTTSModel
.
from_pretrained
(
"fusing/unet-grad-tts-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
num_features
=
model
.
config
.
n_feats
seq_len
=
16
noise
=
torch
.
randn
((
1
,
num_features
,
seq_len
))
condition
=
torch
.
randn
((
1
,
num_features
,
seq_len
))
mask
=
torch
.
randn
((
1
,
1
,
seq_len
))
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
,
condition
,
mask
)
output_slice
=
output
[
0
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.0690
,
-
0.0531
,
0.0633
,
-
0.0660
,
-
0.0541
,
0.0650
,
-
0.0656
,
-
0.0555
,
0.0617
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
PipelineTesterMixin
(
unittest
.
TestCase
):
def
test_from_pretrained_save_pretrained
(
self
):
# 1. Load models
model
=
UNetModel
(
ch
=
32
,
ch_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
32
)
schedular
=
DDPMScheduler
(
timesteps
=
10
)
ddpm
=
DDPMPipeline
(
model
,
schedular
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
ddpm
.
save_pretrained
(
tmpdirname
)
new_ddpm
=
DDPMPipeline
.
from_pretrained
(
tmpdirname
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
generator
=
generator
.
manual_seed
(
0
)
new_image
=
new_ddpm
(
generator
=
generator
)
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
@
slow
def
test_from_pretrained_hub
(
self
):
model_path
=
"fusing/ddpm-cifar10"
ddpm
=
DDPMPipeline
.
from_pretrained
(
model_path
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
)
ddpm
.
noise_scheduler
.
num_timesteps
=
10
ddpm_from_hub
.
noise_scheduler
.
num_timesteps
=
10
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
generator
=
generator
.
manual_seed
(
0
)
new_image
=
ddpm_from_hub
(
generator
=
generator
)
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
@
slow
def
test_ddpm_cifar10
(
self
):
generator
=
torch
.
manual_seed
(
0
)
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetModel
.
from_pretrained
(
model_id
)
noise_scheduler
=
DDPMScheduler
.
from_config
(
model_id
)
noise_scheduler
=
noise_scheduler
.
set_format
(
"pt"
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
image
=
ddpm
(
generator
=
generator
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
expected_slice
=
torch
.
tensor
([
0.2250
,
0.3375
,
0.2360
,
0.0930
,
0.3440
,
0.3156
,
0.1937
,
0.3585
,
0.1761
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_ddim_cifar10
(
self
):
generator
=
torch
.
manual_seed
(
0
)
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetModel
.
from_pretrained
(
model_id
)
noise_scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
expected_slice
=
torch
.
tensor
(
[
-
0.7383
,
-
0.7385
,
-
0.7298
,
-
0.7364
,
-
0.7414
,
-
0.7239
,
-
0.6737
,
-
0.6813
,
-
0.7068
]
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_pndm_cifar10
(
self
):
generator
=
torch
.
manual_seed
(
0
)
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetModel
.
from_pretrained
(
model_id
)
noise_scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
pndm
=
PNDMPipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
image
=
pndm
(
generator
=
generator
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
expected_slice
=
torch
.
tensor
(
[
-
0.7888
,
-
0.7870
,
-
0.7759
,
-
0.7823
,
-
0.8014
,
-
0.7608
,
-
0.6818
,
-
0.7130
,
-
0.7471
]
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_ldm_text2img
(
self
):
model_id
=
"fusing/latent-diffusion-text2im-large"
ldm
=
LatentDiffusionPipeline
.
from_pretrained
(
model_id
)
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
manual_seed
(
0
)
image
=
ldm
([
prompt
],
generator
=
generator
,
num_inference_steps
=
20
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
expected_slice
=
torch
.
tensor
([
0.7295
,
0.7358
,
0.7256
,
0.7435
,
0.7095
,
0.6884
,
0.7325
,
0.6921
,
0.6458
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_glide_text2img
(
self
):
model_id
=
"fusing/glide-base"
glide
=
GlidePipeline
.
from_pretrained
(
model_id
)
prompt
=
"a pencil sketch of a corgi"
generator
=
torch
.
manual_seed
(
0
)
image
=
glide
(
prompt
,
generator
=
generator
,
num_inference_steps_upscale
=
20
)
image_slice
=
image
[
0
,
:
3
,
:
3
,
-
1
].
cpu
()
assert
image
.
shape
==
(
1
,
256
,
256
,
3
)
expected_slice
=
torch
.
tensor
([
0.7119
,
0.7073
,
0.6460
,
0.7780
,
0.7423
,
0.6926
,
0.7378
,
0.7189
,
0.7784
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_grad_tts
(
self
):
model_id
=
"fusing/grad-tts-libri-tts"
grad_tts
=
GradTTSPipeline
.
from_pretrained
(
model_id
)
noise_scheduler
=
GradTTSScheduler
()
grad_tts
.
noise_scheduler
=
noise_scheduler
text
=
"Hello world, I missed you so much."
generator
=
torch
.
manual_seed
(
0
)
# generate mel spectograms using text
mel_spec
=
grad_tts
(
text
,
generator
=
generator
)
assert
mel_spec
.
shape
==
(
1
,
80
,
143
)
expected_slice
=
torch
.
tensor
(
[
-
6.7584
,
-
6.8347
,
-
6.3293
,
-
6.6437
,
-
6.7233
,
-
6.4684
,
-
6.1187
,
-
6.3172
,
-
6.6890
]
)
assert
(
mel_spec
[
0
,
:
3
,
:
3
].
cpu
().
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
def
test_module_from_pipeline
(
self
):
model
=
DiffWave
(
num_res_layers
=
4
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
bddm
=
BDDMPipeline
(
model
,
noise_scheduler
)
# check if the library name for the diffwave moduel is set to pipeline module
self
.
assertTrue
(
bddm
.
config
[
"diffwave"
][
0
]
==
"pipeline_bddm"
)
# check if we can save and load the pipeline
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
bddm
.
save_pretrained
(
tmpdirname
)
_
=
BDDMPipeline
.
from_pretrained
(
tmpdirname
)
# check if the same works using the DifusionPipeline class
_
=
DiffusionPipeline
.
from_pretrained
(
tmpdirname
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment