Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
02a76c2c
Commit
02a76c2c
authored
Jun 27, 2022
by
Patrick von Platen
Browse files
consolidate timestep embeds
parent
014ebc59
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
173 additions
and
853 deletions
+173
-853
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+84
-65
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+22
-21
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+23
-22
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+22
-30
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+12
-11
tests/test_layers_utils.py
tests/test_layers_utils.py
+10
-704
No files found.
src/diffusers/models/embeddings.py
View file @
02a76c2c
...
@@ -11,49 +11,104 @@
...
@@ -11,49 +11,104 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
torch
import
math
import
numpy
as
np
from
torch
import
nn
import
torch.nn.functional
as
F
# unet.py
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
def
get_timestep_embedding
(
timesteps
,
embedding_dim
,
flip_sin_to_cos
=
False
,
downscale_freq_shift
=
1
,
max_period
=
10000
):
"""
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
Create sinusoidal timestep embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
"""
assert
len
(
timesteps
.
shape
)
==
1
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
/
(
embedding_dim
//
2
-
downscale_freq_shift
)
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
timesteps
[:,
None
].
float
()
*
emb
[
None
,
:]
# concat sine and cosine embeddings
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
# flip sine and cosine embeddings
if
flip_sin_to_cos
:
emb
=
torch
.
cat
([
emb
[:,
half_dim
:],
emb
[:,
:
half_dim
]],
dim
=-
1
)
# zero pad
if
embedding_dim
%
2
==
1
:
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
return
emb
# unet_glide.py
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
#def get_timestep_embedding(timesteps, embedding_dim):
These may be fractional.
# """
:param dim: the dimension of the output.
# This matches the implementation in Denoising Diffusion Probabilistic Models:
:param max_period: controls the minimum frequency of the embeddings.
# From Fairseq.
:return: an [N x dim] Tensor of positional embeddings.
# Build sinusoidal embeddings.
"""
# This matches the implementation in tensor2tensor, but differs slightly
half
=
dim
//
2
# from the description in Section 3.5 of "Attention Is All You Need".
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
# """
device
=
timesteps
.
device
# assert len(timesteps.shape) == 1
)
#
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
# half_dim = embedding_dim // 2
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
# emb = math.log(10000) / (half_dim - 1)
if
dim
%
2
:
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
# emb = emb.to(device=timesteps.device)
return
embedding
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
#def timestep_embedding(timesteps, dim, max_period=10000):
# """
# Create sinusoidal timestep embeddings.
#
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
# """
# half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
# device=timesteps.device
# )
# args = timesteps[:, None].float() * freqs[None, :]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# return embedding
#def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
# half_dim = embedding_dim // 2
# magic number 10000 is from transformers
# emb = math.log(max_positions) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = F.pad(emb, (0, 1), mode="constant")
# assert emb.shape == (timesteps.shape[0], embedding_dim)
# return emb
# unet_grad_tts.py
# unet_grad_tts.py
class
SinusoidalPosEmb
(
torch
.
nn
.
Module
):
class
SinusoidalPosEmb
(
torch
.
nn
.
Module
):
...
@@ -70,26 +125,6 @@ class SinusoidalPosEmb(torch.nn.Module):
...
@@ -70,26 +125,6 @@ class SinusoidalPosEmb(torch.nn.Module):
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
return
emb
return
emb
# unet_ldm.py
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half
=
dim
//
2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
device
=
timesteps
.
device
)
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
if
dim
%
2
:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
return
embedding
# unet_rl.py
# unet_rl.py
class
SinusoidalPosEmb
(
nn
.
Module
):
class
SinusoidalPosEmb
(
nn
.
Module
):
...
@@ -106,22 +141,6 @@ class SinusoidalPosEmb(nn.Module):
...
@@ -106,22 +141,6 @@ class SinusoidalPosEmb(nn.Module):
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
return
emb
return
emb
# unet_sde_score_estimation.py
def
get_timestep_embedding
(
timesteps
,
embedding_dim
,
max_positions
=
10000
):
assert
len
(
timesteps
.
shape
)
==
1
# and timesteps.dtype == tf.int32
half_dim
=
embedding_dim
//
2
# magic number 10000 is from transformers
emb
=
math
.
log
(
max_positions
)
/
(
half_dim
-
1
)
# emb = math.log(2.) / (half_dim - 1)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
,
device
=
timesteps
.
device
)
*
-
emb
)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
F
.
pad
(
emb
,
(
0
,
1
),
mode
=
"constant"
)
assert
emb
.
shape
==
(
timesteps
.
shape
[
0
],
embedding_dim
)
return
emb
# unet_sde_score_estimation.py
# unet_sde_score_estimation.py
class
GaussianFourierProjection
(
nn
.
Module
):
class
GaussianFourierProjection
(
nn
.
Module
):
...
...
src/diffusers/models/unet.py
View file @
02a76c2c
...
@@ -30,27 +30,28 @@ from tqdm import tqdm
...
@@ -30,27 +30,28 @@ from tqdm import tqdm
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
#def get_timestep_embedding(timesteps, embedding_dim):
This matches the implementation in Denoising Diffusion Probabilistic Models:
# """
From Fairseq.
# This matches the implementation in Denoising Diffusion Probabilistic Models:
Build sinusoidal embeddings.
# From Fairseq.
This matches the implementation in tensor2tensor, but differs slightly
# Build sinusoidal embeddings.
from the description in Section 3.5 of "Attention Is All You Need".
# This matches the implementation in tensor2tensor, but differs slightly
"""
# from the description in Section 3.5 of "Attention Is All You Need".
assert
len
(
timesteps
.
shape
)
==
1
# """
# assert len(timesteps.shape) == 1
half_dim
=
embedding_dim
//
2
#
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
# half_dim = embedding_dim // 2
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
# emb = math.log(10000) / (half_dim - 1)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
# emb = emb.to(device=timesteps.device)
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
# emb = timesteps.float()[:, None] * emb[None, :]
if
embedding_dim
%
2
==
1
:
# zero pad
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
# if embedding_dim % 2 == 1: # zero pad
return
emb
# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
# return emb
def
nonlinearity
(
x
):
def
nonlinearity
(
x
):
...
...
src/diffusers/models/unet_glide.py
View file @
02a76c2c
...
@@ -7,6 +7,7 @@ import torch.nn.functional as F
...
@@ -7,6 +7,7 @@ import torch.nn.functional as F
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
def
convert_module_to_f16
(
l
):
def
convert_module_to_f16
(
l
):
...
@@ -86,25 +87,25 @@ def normalization(channels, swish=0.0):
...
@@ -86,25 +87,25 @@ def normalization(channels, swish=0.0):
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
#
def timestep_embedding(timesteps, dim, max_period=10000):
"""
#
"""
Create sinusoidal timestep embeddings.
#
Create sinusoidal timestep embeddings.
#
:param timesteps: a 1-D Tensor of N indices, one per batch element.
#
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
#
These may be fractional.
:param dim: the dimension of the output.
#
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
#
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
#
:return: an [N x dim] Tensor of positional embeddings.
"""
#
"""
half
=
dim
//
2
#
half = dim // 2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
#
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device
=
timesteps
.
device
#
device=timesteps.device
)
#
)
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
#
args = timesteps[:, None].float() * freqs[None]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
#
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if
dim
%
2
:
#
if dim % 2:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
#
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return
embedding
#
return embedding
def
zero_module
(
module
):
def
zero_module
(
module
):
...
@@ -627,7 +628,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -627,7 +628,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
"""
"""
hs
=
[]
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
))
h
=
x
.
type
(
self
.
dtype
)
h
=
x
.
type
(
self
.
dtype
)
for
module
in
self
.
input_blocks
:
for
module
in
self
.
input_blocks
:
...
@@ -714,7 +715,7 @@ class GlideTextToImageUNetModel(GlideUNetModel):
...
@@ -714,7 +715,7 @@ class GlideTextToImageUNetModel(GlideUNetModel):
def
forward
(
self
,
x
,
timesteps
,
transformer_out
=
None
):
def
forward
(
self
,
x
,
timesteps
,
transformer_out
=
None
):
hs
=
[]
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
))
# project the last token
# project the last token
transformer_proj
=
self
.
transformer_proj
(
transformer_out
[:,
-
1
])
transformer_proj
=
self
.
transformer_proj
(
transformer_out
[:,
-
1
])
...
@@ -806,7 +807,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
...
@@ -806,7 +807,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
x
=
torch
.
cat
([
x
,
upsampled
],
dim
=
1
)
x
=
torch
.
cat
([
x
,
upsampled
],
dim
=
1
)
hs
=
[]
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
))
h
=
x
h
=
x
for
module
in
self
.
input_blocks
:
for
module
in
self
.
input_blocks
:
...
...
src/diffusers/models/unet_ldm.py
View file @
02a76c2c
...
@@ -16,6 +16,7 @@ except:
...
@@ -16,6 +16,7 @@ except:
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
def
exists
(
val
):
def
exists
(
val
):
...
@@ -316,34 +317,25 @@ def normalization(channels, swish=0.0):
...
@@ -316,34 +317,25 @@ def normalization(channels, swish=0.0):
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
#def timestep_embedding(timesteps, dim, max_period=10000):
"""
# """
Create sinusoidal timestep embeddings.
# Create sinusoidal timestep embeddings.
#
:param timesteps: a 1-D Tensor of N indices, one per batch element.
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
# These may be fractional.
:param dim: the dimension of the output.
# :param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
# :param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
"""
# """
half
=
dim
//
2
# half = dim // 2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device
=
timesteps
.
device
# device=timesteps.device
)
# )
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
# args = timesteps[:, None].float() * freqs[None]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if
dim
%
2
:
# if dim % 2:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return
embedding
# return embedding
def
zero_module
(
module
):
"""
Zero out the parameters of a module and return it.
"""
for
p
in
module
.
parameters
():
p
.
detach
().
zero_
()
return
module
## go
## go
...
@@ -1026,7 +1018,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -1026,7 +1018,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
hs
=
[]
hs
=
[]
if
not
torch
.
is_tensor
(
timesteps
):
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
x
.
device
)
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
x
.
device
)
t_emb
=
timestep_embedding
(
timesteps
,
self
.
model_channels
)
t_emb
=
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
)
emb
=
self
.
time_embed
(
t_emb
)
emb
=
self
.
time_embed
(
t_emb
)
if
self
.
num_classes
is
not
None
:
if
self
.
num_classes
is
not
None
:
...
@@ -1240,7 +1232,7 @@ class EncoderUNetModel(nn.Module):
...
@@ -1240,7 +1232,7 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps.
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
:return: an [N x K] Tensor of outputs.
"""
"""
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
))
results
=
[]
results
=
[]
h
=
x
.
type
(
self
.
dtype
)
h
=
x
.
type
(
self
.
dtype
)
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
02a76c2c
...
@@ -26,6 +26,7 @@ import torch.nn.functional as F
...
@@ -26,6 +26,7 @@ import torch.nn.functional as F
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
def
upfirdn2d
(
input
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
def
upfirdn2d
(
input
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
...
@@ -381,21 +382,21 @@ def get_act(nonlinearity):
...
@@ -381,21 +382,21 @@ def get_act(nonlinearity):
raise
NotImplementedError
(
"activation function does not exist!"
)
raise
NotImplementedError
(
"activation function does not exist!"
)
def
get_timestep_embedding
(
timesteps
,
embedding_dim
,
max_positions
=
10000
):
#
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert
len
(
timesteps
.
shape
)
==
1
# and timesteps.dtype == tf.int32
#
assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
half_dim
=
embedding_dim
//
2
#
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
# magic number 10000 is from transformers
emb
=
math
.
log
(
max_positions
)
/
(
half_dim
-
1
)
#
emb = math.log(max_positions) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
,
device
=
timesteps
.
device
)
*
-
emb
)
#
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
#
emb = timesteps.float()[:, None] * emb[None, :]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
#
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if
embedding_dim
%
2
==
1
:
# zero pad
#
if embedding_dim % 2 == 1: # zero pad
emb
=
F
.
pad
(
emb
,
(
0
,
1
),
mode
=
"constant"
)
#
emb = F.pad(emb, (0, 1), mode="constant")
assert
emb
.
shape
==
(
timesteps
.
shape
[
0
],
embedding_dim
)
#
assert emb.shape == (timesteps.shape[0], embedding_dim)
return
emb
#
return emb
def
default_init
(
scale
=
1.0
):
def
default_init
(
scale
=
1.0
):
...
...
tests/test_layers_utils.py
View file @
02a76c2c
...
@@ -21,718 +21,24 @@ import unittest
...
@@ -21,718 +21,24 @@ import unittest
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
diffusers
import
(
#from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding, a_get_timestep_embedding
BDDMPipeline
,
from
diffusers.models.embeddings
import
get_timestep_embedding
,
timestep_embedding
DDIMPipeline
,
DDIMScheduler
,
DDPMPipeline
,
DDPMScheduler
,
GlidePipeline
,
GlideSuperResUNetModel
,
GlideTextToImageUNetModel
,
GradTTSPipeline
,
GradTTSScheduler
,
LatentDiffusionPipeline
,
PNDMPipeline
,
PNDMScheduler
,
UNetGradTTSModel
,
UNetLDMModel
,
UNetModel
,
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipelines.pipeline_bddm
import
DiffWave
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
class
ConfigTester
(
unittest
.
TestCase
):
class
EmbeddingsTests
(
unittest
.
TestCase
):
def
test_load_not_from_mixin
(
self
):
with
self
.
assertRaises
(
ValueError
):
ConfigMixin
.
from_config
(
"dummy_path"
)
def
test_
save_load
(
self
):
def
test_
timestep_embeddings
(
self
):
class
SampleObject
(
ConfigMixin
):
embedding_dim
=
16
config_name
=
"config.json"
timesteps
=
torch
.
arange
(
10
)
def
__init__
(
t1
=
get_timestep_embedding
(
timesteps
,
embedding_dim
)
self
,
t2
=
timestep_embedding
(
timesteps
,
embedding_dim
)
a
=
2
,
t3
=
get_timestep_embedding
(
timesteps
,
embedding_dim
,
flip_sin_to_cos
=
True
,
downscale_freq_factor
=
8
)
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
e
=
[
1
,
3
],
):
self
.
register_to_config
(
a
=
a
,
b
=
b
,
c
=
c
,
d
=
d
,
e
=
e
)
obj
=
SampleObject
()
import
ipdb
;
ipdb
.
set_trace
()
config
=
obj
.
config
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
(
2
,
5
)
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
new_obj
=
SampleObject
.
from_config
(
tmpdirname
)
new_config
=
new_obj
.
config
# unfreeze configs
config
=
dict
(
config
)
new_config
=
dict
(
new_config
)
assert
config
.
pop
(
"c"
)
==
(
2
,
5
)
# instantiated as tuple
assert
new_config
.
pop
(
"c"
)
==
[
2
,
5
]
# saved & loaded as list because of json
assert
config
==
new_config
class
ModelTesterMixin
:
def
test_from_pretrained_save_pretrained
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_pretrained
(
tmpdirname
)
new_model
=
self
.
model_class
.
from_pretrained
(
tmpdirname
)
new_model
.
to
(
torch_device
)
with
torch
.
no_grad
():
image
=
model
(
**
inputs_dict
)
new_image
=
new_model
(
**
inputs_dict
)
max_diff
=
(
image
-
new_image
).
abs
().
sum
().
item
()
self
.
assertLessEqual
(
max_diff
,
1e-5
,
"Models give different forward passes"
)
def
test_determinism
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
torch
.
no_grad
():
first
=
model
(
**
inputs_dict
)
second
=
model
(
**
inputs_dict
)
out_1
=
first
.
cpu
().
numpy
()
out_2
=
second
.
cpu
().
numpy
()
out_1
=
out_1
[
~
np
.
isnan
(
out_1
)]
out_2
=
out_2
[
~
np
.
isnan
(
out_2
)]
max_diff
=
np
.
amax
(
np
.
abs
(
out_1
-
out_2
))
self
.
assertLessEqual
(
max_diff
,
1e-5
)
def
test_output
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
torch
.
no_grad
():
output
=
model
(
**
inputs_dict
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"x"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_forward_signature
(
self
):
init_dict
,
_
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
signature
=
inspect
.
signature
(
model
.
forward
)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names
=
[
*
signature
.
parameters
.
keys
()]
expected_arg_names
=
[
"x"
,
"timesteps"
]
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
def
test_model_from_config
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
# test if the model can be loaded from the config
# and has all the expected shape
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_config
(
tmpdirname
)
new_model
=
self
.
model_class
.
from_config
(
tmpdirname
)
new_model
.
to
(
torch_device
)
new_model
.
eval
()
# check if all paramters shape are the same
for
param_name
in
model
.
state_dict
().
keys
():
param_1
=
model
.
state_dict
()[
param_name
]
param_2
=
new_model
.
state_dict
()[
param_name
]
self
.
assertEqual
(
param_1
.
shape
,
param_2
.
shape
)
with
torch
.
no_grad
():
output_1
=
model
(
**
inputs_dict
)
output_2
=
new_model
(
**
inputs_dict
)
self
.
assertEqual
(
output_1
.
shape
,
output_2
.
shape
)
def
test_training
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
train
()
output
=
model
(
**
inputs_dict
)
noise
=
torch
.
randn
((
inputs_dict
[
"x"
].
shape
[
0
],)
+
self
.
get_output_shape
).
to
(
torch_device
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
.
backward
()
class
UnetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
get_input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"ch"
:
32
,
"ch_mult"
:
(
1
,
2
),
"num_res_blocks"
:
2
,
"attn_resolutions"
:
(
16
,),
"resolution"
:
32
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
GlideSuperResUNetTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
GlideSuperResUNetModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
6
sizes
=
(
32
,
32
)
low_res_size
=
(
4
,
4
)
noise
=
torch
.
randn
((
batch_size
,
num_channels
//
2
)
+
sizes
).
to
(
torch_device
)
low_res
=
torch
.
randn
((
batch_size
,
3
)
+
low_res_size
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"low_res"
:
low_res
}
@
property
def
get_input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
6
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"attention_resolutions"
:
(
2
,),
"channel_mult"
:
(
1
,
2
),
"in_channels"
:
6
,
"out_channels"
:
6
,
"model_channels"
:
32
,
"num_head_channels"
:
8
,
"num_heads_upsample"
:
1
,
"num_res_blocks"
:
2
,
"resblock_updown"
:
True
,
"resolution"
:
32
,
"use_scale_shift_norm"
:
True
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_output
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
torch
.
no_grad
():
output
=
model
(
**
inputs_dict
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"x"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
GlideSuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
GlideSuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
)
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
3
,
64
,
64
)
low_res
=
torch
.
randn
(
1
,
3
,
4
,
4
)
time_step
=
torch
.
tensor
([
42
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
,
low_res
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
22.8782
,
-
23.2652
,
-
15.3966
,
-
22.8034
,
-
23.3159
,
-
15.5640
,
-
15.3970
,
-
15.4614
,
-
10.4370
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
GlideTextToImageUNetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
GlideTextToImageUNetModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
transformer_dim
=
32
seq_len
=
16
noise
=
torch
.
randn
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
emb
=
torch
.
randn
((
batch_size
,
seq_len
,
transformer_dim
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"transformer_out"
:
emb
}
@
property
def
get_input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
6
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"attention_resolutions"
:
(
2
,),
"channel_mult"
:
(
1
,
2
),
"in_channels"
:
3
,
"out_channels"
:
6
,
"model_channels"
:
32
,
"num_head_channels"
:
8
,
"num_heads_upsample"
:
1
,
"num_res_blocks"
:
2
,
"resblock_updown"
:
True
,
"resolution"
:
32
,
"use_scale_shift_norm"
:
True
,
"transformer_dim"
:
32
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_output
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
torch
.
no_grad
():
output
=
model
(
**
inputs_dict
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"x"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
GlideTextToImageUNetModel
.
from_pretrained
(
"fusing/unet-glide-text2im-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
GlideTextToImageUNetModel
.
from_pretrained
(
"fusing/unet-glide-text2im-dummy"
)
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
((
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)).
to
(
torch_device
)
emb
=
torch
.
randn
((
1
,
16
,
model
.
config
.
transformer_dim
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
,
emb
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
2.7766
,
-
10.3558
,
-
14.9149
,
-
0.9376
,
-
14.9175
,
-
17.7679
,
-
5.5565
,
-
12.9521
,
-
12.9845
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
UNetLDMModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetLDMModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
4
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
get_input_shape
(
self
):
return
(
4
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
4
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"image_size"
:
32
,
"in_channels"
:
4
,
"out_channels"
:
4
,
"model_channels"
:
32
,
"num_res_blocks"
:
2
,
"attention_resolutions"
:
(
16
,),
"channel_mult"
:
(
1
,
2
),
"num_heads"
:
2
,
"conv_resample"
:
True
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetLDMModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetLDMModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
image_size
,
model
.
config
.
image_size
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
13.3258
,
-
20.1100
,
-
15.9873
,
-
17.6617
,
-
23.0596
,
-
17.9419
,
-
13.3675
,
-
16.1889
,
-
12.3800
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
UNetGradTTSModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetGradTTSModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_features
=
32
seq_len
=
16
noise
=
floats_tensor
((
batch_size
,
num_features
,
seq_len
)).
to
(
torch_device
)
condition
=
floats_tensor
((
batch_size
,
num_features
,
seq_len
)).
to
(
torch_device
)
mask
=
floats_tensor
((
batch_size
,
1
,
seq_len
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
batch_size
).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"mu"
:
condition
,
"mask"
:
mask
}
@
property
def
get_input_shape
(
self
):
return
(
4
,
32
,
16
)
@
property
def
get_output_shape
(
self
):
return
(
4
,
32
,
16
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"dim"
:
64
,
"groups"
:
4
,
"dim_mults"
:
(
1
,
2
),
"n_feats"
:
32
,
"pe_scale"
:
1000
,
"n_spks"
:
1
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetGradTTSModel
.
from_pretrained
(
"fusing/unet-grad-tts-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetGradTTSModel
.
from_pretrained
(
"fusing/unet-grad-tts-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
num_features
=
model
.
config
.
n_feats
seq_len
=
16
noise
=
torch
.
randn
((
1
,
num_features
,
seq_len
))
condition
=
torch
.
randn
((
1
,
num_features
,
seq_len
))
mask
=
torch
.
randn
((
1
,
1
,
seq_len
))
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
,
condition
,
mask
)
output_slice
=
output
[
0
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.0690
,
-
0.0531
,
0.0633
,
-
0.0660
,
-
0.0541
,
0.0650
,
-
0.0656
,
-
0.0555
,
0.0617
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
PipelineTesterMixin
(
unittest
.
TestCase
):
def
test_from_pretrained_save_pretrained
(
self
):
# 1. Load models
model
=
UNetModel
(
ch
=
32
,
ch_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
32
)
schedular
=
DDPMScheduler
(
timesteps
=
10
)
ddpm
=
DDPMPipeline
(
model
,
schedular
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
ddpm
.
save_pretrained
(
tmpdirname
)
new_ddpm
=
DDPMPipeline
.
from_pretrained
(
tmpdirname
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
generator
=
generator
.
manual_seed
(
0
)
new_image
=
new_ddpm
(
generator
=
generator
)
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
@
slow
def
test_from_pretrained_hub
(
self
):
model_path
=
"fusing/ddpm-cifar10"
ddpm
=
DDPMPipeline
.
from_pretrained
(
model_path
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
)
ddpm
.
noise_scheduler
.
num_timesteps
=
10
ddpm_from_hub
.
noise_scheduler
.
num_timesteps
=
10
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
generator
=
generator
.
manual_seed
(
0
)
new_image
=
ddpm_from_hub
(
generator
=
generator
)
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
@
slow
def
test_ddpm_cifar10
(
self
):
generator
=
torch
.
manual_seed
(
0
)
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetModel
.
from_pretrained
(
model_id
)
noise_scheduler
=
DDPMScheduler
.
from_config
(
model_id
)
noise_scheduler
=
noise_scheduler
.
set_format
(
"pt"
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
image
=
ddpm
(
generator
=
generator
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
expected_slice
=
torch
.
tensor
([
0.2250
,
0.3375
,
0.2360
,
0.0930
,
0.3440
,
0.3156
,
0.1937
,
0.3585
,
0.1761
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_ddim_cifar10
(
self
):
generator
=
torch
.
manual_seed
(
0
)
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetModel
.
from_pretrained
(
model_id
)
noise_scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
expected_slice
=
torch
.
tensor
(
[
-
0.7383
,
-
0.7385
,
-
0.7298
,
-
0.7364
,
-
0.7414
,
-
0.7239
,
-
0.6737
,
-
0.6813
,
-
0.7068
]
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_pndm_cifar10
(
self
):
generator
=
torch
.
manual_seed
(
0
)
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetModel
.
from_pretrained
(
model_id
)
noise_scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
pndm
=
PNDMPipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
image
=
pndm
(
generator
=
generator
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
expected_slice
=
torch
.
tensor
(
[
-
0.7888
,
-
0.7870
,
-
0.7759
,
-
0.7823
,
-
0.8014
,
-
0.7608
,
-
0.6818
,
-
0.7130
,
-
0.7471
]
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_ldm_text2img
(
self
):
model_id
=
"fusing/latent-diffusion-text2im-large"
ldm
=
LatentDiffusionPipeline
.
from_pretrained
(
model_id
)
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
manual_seed
(
0
)
image
=
ldm
([
prompt
],
generator
=
generator
,
num_inference_steps
=
20
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
expected_slice
=
torch
.
tensor
([
0.7295
,
0.7358
,
0.7256
,
0.7435
,
0.7095
,
0.6884
,
0.7325
,
0.6921
,
0.6458
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_glide_text2img
(
self
):
model_id
=
"fusing/glide-base"
glide
=
GlidePipeline
.
from_pretrained
(
model_id
)
prompt
=
"a pencil sketch of a corgi"
generator
=
torch
.
manual_seed
(
0
)
image
=
glide
(
prompt
,
generator
=
generator
,
num_inference_steps_upscale
=
20
)
image_slice
=
image
[
0
,
:
3
,
:
3
,
-
1
].
cpu
()
assert
image
.
shape
==
(
1
,
256
,
256
,
3
)
expected_slice
=
torch
.
tensor
([
0.7119
,
0.7073
,
0.6460
,
0.7780
,
0.7423
,
0.6926
,
0.7378
,
0.7189
,
0.7784
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_grad_tts
(
self
):
model_id
=
"fusing/grad-tts-libri-tts"
grad_tts
=
GradTTSPipeline
.
from_pretrained
(
model_id
)
noise_scheduler
=
GradTTSScheduler
()
grad_tts
.
noise_scheduler
=
noise_scheduler
text
=
"Hello world, I missed you so much."
generator
=
torch
.
manual_seed
(
0
)
# generate mel spectograms using text
mel_spec
=
grad_tts
(
text
,
generator
=
generator
)
assert
mel_spec
.
shape
==
(
1
,
80
,
143
)
expected_slice
=
torch
.
tensor
(
[
-
6.7584
,
-
6.8347
,
-
6.3293
,
-
6.6437
,
-
6.7233
,
-
6.4684
,
-
6.1187
,
-
6.3172
,
-
6.6890
]
)
assert
(
mel_spec
[
0
,
:
3
,
:
3
].
cpu
().
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
def
test_module_from_pipeline
(
self
):
model
=
DiffWave
(
num_res_layers
=
4
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
bddm
=
BDDMPipeline
(
model
,
noise_scheduler
)
# check if the library name for the diffwave moduel is set to pipeline module
self
.
assertTrue
(
bddm
.
config
[
"diffwave"
][
0
]
==
"pipeline_bddm"
)
# check if we can save and load the pipeline
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
bddm
.
save_pretrained
(
tmpdirname
)
_
=
BDDMPipeline
.
from_pretrained
(
tmpdirname
)
# check if the same works using the DifusionPipeline class
_
=
DiffusionPipeline
.
from_pretrained
(
tmpdirname
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment