renzhc/diffusers_dcu

Commit c7a39d38, authored Jun 27, 2022 by Patrick von Platen
Parent: 02a76c2c

    refactor all sinus embeddings

Showing 7 changed files with 117 additions and 204 deletions (+117 / -204)
Files changed:

src/diffusers/models/embeddings.py                   +25  -94
src/diffusers/models/unet.py                          +0  -21
src/diffusers/models/unet_glide.py                    +9  -24
src/diffusers/models/unet_grad_tts.py                 +3  -20
src/diffusers/models/unet_ldm.py                      +3  -22
src/diffusers/models/unet_sde_score_estimation.py     +0  -17
tests/test_layers_utils.py                           +77   -6
src/diffusers/models/embeddings.py

@@ -11,15 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import torch
 import math
 import numpy as np
+import torch
 from torch import nn
-import torch.nn.functional as F


-def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000):
+def get_timestep_embedding(
+    timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
+):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:

@@ -31,18 +32,22 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down
     :param max_period: controls the minimum frequency of the embeddings.
     :return: an [N x dim] Tensor of positional embeddings.
     """
-    assert len(timesteps.shape) == 1
+    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

     half_dim = embedding_dim // 2
-    emb = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / (embedding_dim // 2 - downscale_freq_shift))
-    emb = emb.to(device=timesteps.device)
+    emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
+    emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
+    emb = torch.exp(emb * emb_coeff)

     emb = timesteps[:, None].float() * emb[None, :]

+    # scale embeddings
+    emb = scale * emb
+
     # concat sine and cosine embeddings
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

     # flip sine and cosine embeddings
     if flip_sin_to_cos:
         emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
@@ -52,81 +57,20 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down
     return emb


-#def get_timestep_embedding(timesteps, embedding_dim):
-# """
-# This matches the implementation in Denoising Diffusion Probabilistic Models:
-# From Fairseq.
-# Build sinusoidal embeddings.
-# This matches the implementation in tensor2tensor, but differs slightly
-# from the description in Section 3.5 of "Attention Is All You Need".
-# """
-# assert len(timesteps.shape) == 1
-#
-# half_dim = embedding_dim // 2
-# emb = math.log(10000) / (half_dim - 1)
-# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
-# emb = emb.to(device=timesteps.device)
-# emb = timesteps.float()[:, None] * emb[None, :]
-# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-# if embedding_dim % 2 == 1:  # zero pad
-# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
-
-
-#def timestep_embedding(timesteps, dim, max_period=10000):
-# """
-# Create sinusoidal timestep embeddings.
-#
-# :param timesteps: a 1-D Tensor of N indices, one per batch element.
-# These may be fractional.
-# :param dim: the dimension of the output.
-# :param max_period: controls the minimum frequency of the embeddings.
-# :return: an [N x dim] Tensor of positional embeddings.
-# """
-# half = dim // 2
-# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-# device=timesteps.device
-# )
-# args = timesteps[:, None].float() * freqs[None, :]
-# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-# if dim % 2:
-# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-# return embedding
-
-
-#def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-# assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
-# half_dim = embedding_dim // 2
-# magic number 10000 is from transformers
-# emb = math.log(max_positions) / (half_dim - 1)
-# emb = math.log(2.) / (half_dim - 1)
-# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
-# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
-# emb = timesteps.float()[:, None] * emb[None, :]
-# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-# if embedding_dim % 2 == 1:  # zero pad
-# emb = F.pad(emb, (0, 1), mode="constant")
-# assert emb.shape == (timesteps.shape[0], embedding_dim)
-# return emb
-
-
-# unet_grad_tts.py
-class SinusoidalPosEmb(torch.nn.Module):
-    def __init__(self, dim):
-        super(SinusoidalPosEmb, self).__init__()
-        self.dim = dim
-
-    def forward(self, x, scale=1000):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
-        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
+# unet_sde_score_estimation.py
+class GaussianFourierProjection(nn.Module):
+    """Gaussian Fourier embeddings for noise levels."""
+
+    def __init__(self, embedding_size=256, scale=1.0):
+        super().__init__()
+        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+    def forward(self, x):
+        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
+        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


-# unet_rl.py
+# unet_rl.py - TODO(need test)
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()

@@ -140,16 +84,3 @@ class SinusoidalPosEmb(nn.Module):
         emb = x[:, None] * emb[None, :]
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb
-
-
-# unet_sde_score_estimation.py
-class GaussianFourierProjection(nn.Module):
-    """Gaussian Fourier embeddings for noise levels."""
-
-    def __init__(self, embedding_size=256, scale=1.0):
-        super().__init__()
-        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
-
-    def forward(self, x):
-        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
-        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
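For orientation, here is a minimal usage sketch of the consolidated helper, using the three parameter combinations exercised by the new tests in tests/test_layers_utils.py (shown further down); the import path is the one introduced by this diff, and the shape comment follows the function's docstring rather than captured output:

    import torch

    from diffusers.models.embeddings import get_timestep_embedding

    timesteps = torch.arange(128)

    # standard DDPM / score-SDE convention: sin first, frequency shift of 1
    t1 = get_timestep_embedding(timesteps, 64, flip_sin_to_cos=False, downscale_freq_shift=1)

    # GLIDE / LDM convention: cos first, no frequency shift
    t2 = get_timestep_embedding(timesteps, 64, flip_sin_to_cos=True, downscale_freq_shift=0)

    # Grad-TTS convention: timesteps multiplied by a positional-encoding scale
    t3 = get_timestep_embedding(timesteps, 64, scale=1000)

    print(t1.shape, t2.shape, t3.shape)  # each is an [128 x 64] tensor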
src/diffusers/models/unet.py

@@ -33,27 +33,6 @@ from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding


-#def get_timestep_embedding(timesteps, embedding_dim):
-# """
-# This matches the implementation in Denoising Diffusion Probabilistic Models:
-# From Fairseq.
-# Build sinusoidal embeddings.
-# This matches the implementation in tensor2tensor, but differs slightly
-# from the description in Section 3.5 of "Attention Is All You Need".
-# """
-# assert len(timesteps.shape) == 1
-#
-# half_dim = embedding_dim // 2
-# emb = math.log(10000) / (half_dim - 1)
-# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
-# emb = emb.to(device=timesteps.device)
-# emb = timesteps.float()[:, None] * emb[None, :]
-# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-# if embedding_dim % 2 == 1:  # zero pad
-# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
-# return emb
-
-
 def nonlinearity(x):
     # swish
     return x * torch.sigmoid(x)
src/diffusers/models/unet_glide.py

@@ -87,27 +87,6 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


-# def timestep_embedding(timesteps, dim, max_period=10000):
-# """
-# Create sinusoidal timestep embeddings.
-#
-# :param timesteps: a 1-D Tensor of N indices, one per batch element.
-# These may be fractional.
-# :param dim: the dimension of the output.
-# :param max_period: controls the minimum frequency of the embeddings.
-# :return: an [N x dim] Tensor of positional embeddings.
-# """
-# half = dim // 2
-# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-# device=timesteps.device
-# )
-# args = timesteps[:, None].float() * freqs[None]
-# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-# if dim % 2:
-# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-# return embedding
-
-
 def zero_module(module):
     """
     Zero out the parameters of a module and return it.

@@ -628,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         """
         hs = []
-        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )
         h = x.type(self.dtype)
         for module in self.input_blocks:

@@ -715,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel):
     def forward(self, x, timesteps, transformer_out=None):
         hs = []
-        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )

         # project the last token
         transformer_proj = self.transformer_proj(transformer_out[:, -1])

@@ -807,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel):
         x = torch.cat([x, upsampled], dim=1)
         hs = []
-        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )
         h = x
         for module in self.input_blocks:
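The get_timestep_embedding call used here replaces the commented-out timestep_embedding helper deleted above: flip_sin_to_cos=True reproduces its cos-then-sin ordering, and downscale_freq_shift=0 reproduces its division by half rather than half - 1. A minimal equivalence sketch (the reference function is copied from the removed comment block; agreement up to floating-point tolerance is the expectation, not a verified result):

    import math

    import torch

    from diffusers.models.embeddings import get_timestep_embedding


    def old_glide_timestep_embedding(timesteps, dim, max_period=10000):
        # behaviour of the helper removed from unet_glide.py, kept here only for comparison
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
        args = timesteps[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


    timesteps = torch.arange(10)
    reference = old_glide_timestep_embedding(timesteps, 32)
    refactored = get_timestep_embedding(timesteps, 32, flip_sin_to_cos=True, downscale_freq_shift=0)
    assert torch.allclose(reference, refactored, atol=1e-4)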
src/diffusers/models/unet_grad_tts.py

-import math
 import torch

@@ -11,6 +9,7 @@ except:
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding


 class Mish(torch.nn.Module):

@@ -107,21 +106,6 @@ class Residual(torch.nn.Module):
         return output


-class SinusoidalPosEmb(torch.nn.Module):
-    def __init__(self, dim):
-        super(SinusoidalPosEmb, self).__init__()
-        self.dim = dim
-
-    def forward(self, x, scale=1000):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
-        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
-
-
 class UNetGradTTSModel(ModelMixin, ConfigMixin):
     def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
         super(UNetGradTTSModel, self).__init__()

@@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                 torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
             )
-        self.time_pos_emb = SinusoidalPosEmb(dim)
         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
         dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]

@@ -198,8 +181,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         if not isinstance(spk, type(None)):
             s = self.spk_mlp(spk)

-        t = self.time_pos_emb(timesteps, scale=self.pe_scale)
+        t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
         t = self.mlp(t)

         if self.n_spks < 2:
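The new scale argument is what lets this file drop its local SinusoidalPosEmb: with the default downscale_freq_shift=1, get_timestep_embedding(x, dim, scale=pe_scale) computes the same embedding the removed module's forward pass did. A sketch of that equivalence (the reference function reproduces the deleted forward; agreement within the loose tolerance below is expected, not verified here):

    import math

    import torch

    from diffusers.models.embeddings import get_timestep_embedding


    def old_grad_tts_pos_emb(x, dim, scale=1000):
        # forward pass of the SinusoidalPosEmb module removed above, for comparison only
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        return torch.cat((emb.sin(), emb.cos()), dim=-1)


    x = torch.rand(8)  # Grad-TTS passes continuous diffusion times
    assert torch.allclose(old_grad_tts_pos_emb(x, 64, scale=1000), get_timestep_embedding(x, 64, scale=1000), atol=1e-3)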
src/diffusers/models/unet_ldm.py

@@ -317,27 +317,6 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


-#def timestep_embedding(timesteps, dim, max_period=10000):
-# """
-# Create sinusoidal timestep embeddings.
-#
-# :param timesteps: a 1-D Tensor of N indices, one per batch element.
-# These may be fractional.
-# :param dim: the dimension of the output.
-# :param max_period: controls the minimum frequency of the embeddings.
-# :return: an [N x dim] Tensor of positional embeddings.
-# """
-# half = dim // 2
-# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-# device=timesteps.device
-# )
-# args = timesteps[:, None].float() * freqs[None]
-# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-# if dim % 2:
-# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-# return embedding
-
-
 ## go
 class AttentionPool2d(nn.Module):
     """

@@ -1232,7 +1211,9 @@ class EncoderUNetModel(nn.Module):
         :param timesteps: a 1-D batch of timesteps.
         :return: an [N x K] Tensor of outputs.
         """
-        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )

         results = []
         h = x.type(self.dtype)
src/diffusers/models/unet_sde_score_estimation.py

@@ -382,23 +382,6 @@ def get_act(nonlinearity):
         raise NotImplementedError("activation function does not exist!")


-#def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-# assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
-# half_dim = embedding_dim // 2
-# magic number 10000 is from transformers
-# emb = math.log(max_positions) / (half_dim - 1)
-# emb = math.log(2.) / (half_dim - 1)
-# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
-# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
-# emb = timesteps.float()[:, None] * emb[None, :]
-# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-# if embedding_dim % 2 == 1:  # zero pad
-# emb = F.pad(emb, (0, 1), mode="constant")
-# assert emb.shape == (timesteps.shape[0], embedding_dim)
-# return emb
-
-
 def default_init(scale=1.0):
     """The same initialization used in DDPM."""
     scale = 1e-10 if scale == 0 else scale
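Related: embeddings.py (the first file in this diff) now hosts a GaussianFourierProjection class tagged with a "# unet_sde_score_estimation.py" comment. A minimal usage sketch of that class, assuming the import path shown in this diff:

    import torch

    from diffusers.models.embeddings import GaussianFourierProjection

    proj = GaussianFourierProjection(embedding_size=256, scale=1.0)
    noise_levels = torch.rand(4)  # one continuous noise level per batch element
    emb = proj(noise_levels)
    print(emb.shape)  # [4, 512]: sin and cos of the random projection, concatenated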
tests/test_layers_utils.py

@@ -21,8 +21,7 @@ import unittest
 import numpy as np
 import torch

-#from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding, a_get_timestep_embedding
-from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding
+from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.testing_utils import floats_tensor, slow, torch_device

@@ -30,15 +29,87 @@ torch.backends.cuda.matmul.allow_tf32 = False

 class EmbeddingsTests(unittest.TestCase):
     def test_timestep_embeddings(self):
-        embedding_dim = 16
-        timesteps = torch.arange(10)
-        t1 = get_timestep_embedding(timesteps, embedding_dim)
-        t2 = timestep_embedding(timesteps, embedding_dim)
-        t3 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True, downscale_freq_factor=8)
-        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
+        embedding_dim = 256
+        timesteps = torch.arange(16)
+
+        t1 = get_timestep_embedding(timesteps, embedding_dim)
+
+        # first vector should always be composed only of 0's and 1's
+        assert (t1[0, : embedding_dim // 2] - 0).abs().sum() < 1e-5
+        assert (t1[0, embedding_dim // 2 :] - 1).abs().sum() < 1e-5
+
+        # last element of each vector should be one
+        assert (t1[:, -1] - 1).abs().sum() < 1e-5
+
+        # For large embeddings (e.g. 128) the frequency of every vector is higher
+        # than the previous one which means that the gradients of later vectors are
+        # ALWAYS higher than the previous ones
+        grad_mean = np.abs(np.gradient(t1, axis=-1)).mean(axis=1)
+
+        prev_grad = 0.0
+        for grad in grad_mean:
+            assert grad > prev_grad
+            prev_grad = grad
+
+    def test_timestep_defaults(self):
+        embedding_dim = 16
+        timesteps = torch.arange(10)
+
+        t1 = get_timestep_embedding(timesteps, embedding_dim)
+        t2 = get_timestep_embedding(
+            timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000
+        )
+
+        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
+
+    def test_timestep_flip_sin_cos(self):
+        embedding_dim = 16
+        timesteps = torch.arange(10)
+
+        t1 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True)
+        t1 = torch.cat([t1[:, embedding_dim // 2 :], t1[:, : embedding_dim // 2]], dim=-1)
+
+        t2 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False)
+
+        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
+
+    def test_timestep_downscale_freq_shift(self):
+        embedding_dim = 16
+        timesteps = torch.arange(10)
+
+        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0)
+        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1)
+
+        # get cosine half (vectors that are wrapped into cosine)
+        cosine_half = (t1 - t2)[:, embedding_dim // 2 :]
+
+        # cosine needs to be negative
+        assert (np.abs((cosine_half <= 0).numpy()) - 1).sum() < 1e-5
+        import ipdb; ipdb.set_trace()
+
+    def test_sinoid_embeddings_hardcoded(self):
+        embedding_dim = 64
+        timesteps = torch.arange(128)
+
+        # standard unet, score_vde
+        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1, flip_sin_to_cos=False)
+        # glide, ldm
+        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0, flip_sin_to_cos=True)
+        # grad-tts
+        t3 = get_timestep_embedding(timesteps, embedding_dim, scale=1000)
+
+        assert torch.allclose(
+            t1[23:26, 47:50].flatten().cpu(),
+            torch.tensor([0.9646, 0.9804, 0.9892, 0.9615, 0.9787, 0.9882, 0.9582, 0.9769, 0.9872]),
+            1e-3,
+        )
+        assert torch.allclose(
+            t2[23:26, 47:50].flatten().cpu(),
+            torch.tensor([0.3019, 0.2280, 0.1716, 0.3146, 0.2377, 0.1790, 0.3272, 0.2474, 0.1864]),
+            1e-3,
+        )
+        assert torch.allclose(
+            t3[23:26, 47:50].flatten().cpu(),
+            torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
+            1e-3,
+        )