Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
02a76c2c
Commit
02a76c2c
authored
Jun 27, 2022
by
Patrick von Platen
Browse files
consolidate timestep embeds
parent
014ebc59
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
173 additions
and
853 deletions
+173
-853
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+84
-65
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+22
-21
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+23
-22
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+22
-30
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+12
-11
tests/test_layers_utils.py
tests/test_layers_utils.py
+10
-704
No files found.
src/diffusers/models/embeddings.py
View file @
02a76c2c
...
@@ -11,49 +11,104 @@
...
@@ -11,49 +11,104 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
torch
import
math
import
numpy
as
np
from
torch
import
nn
import
torch.nn.functional
as
F
# unet.py
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
def
get_timestep_embedding
(
timesteps
,
embedding_dim
,
flip_sin_to_cos
=
False
,
downscale_freq_shift
=
1
,
max_period
=
10000
):
"""
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
Create sinusoidal timestep embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
"""
assert
len
(
timesteps
.
shape
)
==
1
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
/
(
embedding_dim
//
2
-
downscale_freq_shift
)
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
timesteps
[:,
None
].
float
()
*
emb
[
None
,
:]
# concat sine and cosine embeddings
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
# flip sine and cosine embeddings
if
flip_sin_to_cos
:
emb
=
torch
.
cat
([
emb
[:,
half_dim
:],
emb
[:,
:
half_dim
]],
dim
=-
1
)
# zero pad
if
embedding_dim
%
2
==
1
:
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
return
emb
# unet_glide.py
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
#def get_timestep_embedding(timesteps, embedding_dim):
These may be fractional.
# """
:param dim: the dimension of the output.
# This matches the implementation in Denoising Diffusion Probabilistic Models:
:param max_period: controls the minimum frequency of the embeddings.
# From Fairseq.
:return: an [N x dim] Tensor of positional embeddings.
# Build sinusoidal embeddings.
"""
# This matches the implementation in tensor2tensor, but differs slightly
half
=
dim
//
2
# from the description in Section 3.5 of "Attention Is All You Need".
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
# """
device
=
timesteps
.
device
# assert len(timesteps.shape) == 1
)
#
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
# half_dim = embedding_dim // 2
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
# emb = math.log(10000) / (half_dim - 1)
if
dim
%
2
:
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
# emb = emb.to(device=timesteps.device)
return
embedding
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
#def timestep_embedding(timesteps, dim, max_period=10000):
# """
# Create sinusoidal timestep embeddings.
#
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
# These may be fractional.
# :param dim: the dimension of the output.
# :param max_period: controls the minimum frequency of the embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
# """
# half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
# device=timesteps.device
# )
# args = timesteps[:, None].float() * freqs[None, :]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# return embedding
#def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
# assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
# half_dim = embedding_dim // 2
# magic number 10000 is from transformers
# emb = math.log(max_positions) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = timesteps.float()[:, None] * emb[None, :]
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# if embedding_dim % 2 == 1: # zero pad
# emb = F.pad(emb, (0, 1), mode="constant")
# assert emb.shape == (timesteps.shape[0], embedding_dim)
# return emb
# unet_grad_tts.py
# unet_grad_tts.py
class
SinusoidalPosEmb
(
torch
.
nn
.
Module
):
class
SinusoidalPosEmb
(
torch
.
nn
.
Module
):
...
@@ -70,26 +125,6 @@ class SinusoidalPosEmb(torch.nn.Module):
...
@@ -70,26 +125,6 @@ class SinusoidalPosEmb(torch.nn.Module):
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
return
emb
return
emb
# unet_ldm.py
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half
=
dim
//
2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
device
=
timesteps
.
device
)
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
if
dim
%
2
:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
return
embedding
# unet_rl.py
# unet_rl.py
class
SinusoidalPosEmb
(
nn
.
Module
):
class
SinusoidalPosEmb
(
nn
.
Module
):
...
@@ -106,22 +141,6 @@ class SinusoidalPosEmb(nn.Module):
...
@@ -106,22 +141,6 @@ class SinusoidalPosEmb(nn.Module):
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
return
emb
return
emb
# unet_sde_score_estimation.py
def
get_timestep_embedding
(
timesteps
,
embedding_dim
,
max_positions
=
10000
):
assert
len
(
timesteps
.
shape
)
==
1
# and timesteps.dtype == tf.int32
half_dim
=
embedding_dim
//
2
# magic number 10000 is from transformers
emb
=
math
.
log
(
max_positions
)
/
(
half_dim
-
1
)
# emb = math.log(2.) / (half_dim - 1)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
,
device
=
timesteps
.
device
)
*
-
emb
)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
F
.
pad
(
emb
,
(
0
,
1
),
mode
=
"constant"
)
assert
emb
.
shape
==
(
timesteps
.
shape
[
0
],
embedding_dim
)
return
emb
# unet_sde_score_estimation.py
# unet_sde_score_estimation.py
class
GaussianFourierProjection
(
nn
.
Module
):
class
GaussianFourierProjection
(
nn
.
Module
):
...
...
src/diffusers/models/unet.py
View file @
02a76c2c
...
@@ -30,27 +30,28 @@ from tqdm import tqdm
...
@@ -30,27 +30,28 @@ from tqdm import tqdm
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
#def get_timestep_embedding(timesteps, embedding_dim):
This matches the implementation in Denoising Diffusion Probabilistic Models:
# """
From Fairseq.
# This matches the implementation in Denoising Diffusion Probabilistic Models:
Build sinusoidal embeddings.
# From Fairseq.
This matches the implementation in tensor2tensor, but differs slightly
# Build sinusoidal embeddings.
from the description in Section 3.5 of "Attention Is All You Need".
# This matches the implementation in tensor2tensor, but differs slightly
"""
# from the description in Section 3.5 of "Attention Is All You Need".
assert
len
(
timesteps
.
shape
)
==
1
# """
# assert len(timesteps.shape) == 1
half_dim
=
embedding_dim
//
2
#
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
# half_dim = embedding_dim // 2
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
# emb = math.log(10000) / (half_dim - 1)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
# emb = emb.to(device=timesteps.device)
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
# emb = timesteps.float()[:, None] * emb[None, :]
if
embedding_dim
%
2
==
1
:
# zero pad
# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
# if embedding_dim % 2 == 1: # zero pad
return
emb
# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
# return emb
def
nonlinearity
(
x
):
def
nonlinearity
(
x
):
...
...
src/diffusers/models/unet_glide.py
View file @
02a76c2c
...
@@ -7,6 +7,7 @@ import torch.nn.functional as F
...
@@ -7,6 +7,7 @@ import torch.nn.functional as F
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
def
convert_module_to_f16
(
l
):
def
convert_module_to_f16
(
l
):
...
@@ -86,25 +87,25 @@ def normalization(channels, swish=0.0):
...
@@ -86,25 +87,25 @@ def normalization(channels, swish=0.0):
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
#
def timestep_embedding(timesteps, dim, max_period=10000):
"""
#
"""
Create sinusoidal timestep embeddings.
#
Create sinusoidal timestep embeddings.
#
:param timesteps: a 1-D Tensor of N indices, one per batch element.
#
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
#
These may be fractional.
:param dim: the dimension of the output.
#
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
#
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
#
:return: an [N x dim] Tensor of positional embeddings.
"""
#
"""
half
=
dim
//
2
#
half = dim // 2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
#
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device
=
timesteps
.
device
#
device=timesteps.device
)
#
)
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
#
args = timesteps[:, None].float() * freqs[None]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
#
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if
dim
%
2
:
#
if dim % 2:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
#
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return
embedding
#
return embedding
def
zero_module
(
module
):
def
zero_module
(
module
):
...
@@ -627,7 +628,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
...
@@ -627,7 +628,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
"""
"""
hs
=
[]
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
))
h
=
x
.
type
(
self
.
dtype
)
h
=
x
.
type
(
self
.
dtype
)
for
module
in
self
.
input_blocks
:
for
module
in
self
.
input_blocks
:
...
@@ -714,7 +715,7 @@ class GlideTextToImageUNetModel(GlideUNetModel):
...
@@ -714,7 +715,7 @@ class GlideTextToImageUNetModel(GlideUNetModel):
def
forward
(
self
,
x
,
timesteps
,
transformer_out
=
None
):
def
forward
(
self
,
x
,
timesteps
,
transformer_out
=
None
):
hs
=
[]
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
))
# project the last token
# project the last token
transformer_proj
=
self
.
transformer_proj
(
transformer_out
[:,
-
1
])
transformer_proj
=
self
.
transformer_proj
(
transformer_out
[:,
-
1
])
...
@@ -806,7 +807,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
...
@@ -806,7 +807,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
x
=
torch
.
cat
([
x
,
upsampled
],
dim
=
1
)
x
=
torch
.
cat
([
x
,
upsampled
],
dim
=
1
)
hs
=
[]
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
))
h
=
x
h
=
x
for
module
in
self
.
input_blocks
:
for
module
in
self
.
input_blocks
:
...
...
src/diffusers/models/unet_ldm.py
View file @
02a76c2c
...
@@ -16,6 +16,7 @@ except:
...
@@ -16,6 +16,7 @@ except:
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
def
exists
(
val
):
def
exists
(
val
):
...
@@ -316,34 +317,25 @@ def normalization(channels, swish=0.0):
...
@@ -316,34 +317,25 @@ def normalization(channels, swish=0.0):
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
#def timestep_embedding(timesteps, dim, max_period=10000):
"""
# """
Create sinusoidal timestep embeddings.
# Create sinusoidal timestep embeddings.
#
:param timesteps: a 1-D Tensor of N indices, one per batch element.
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
# These may be fractional.
:param dim: the dimension of the output.
# :param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
# :param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
# :return: an [N x dim] Tensor of positional embeddings.
"""
# """
half
=
dim
//
2
# half = dim // 2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device
=
timesteps
.
device
# device=timesteps.device
)
# )
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
# args = timesteps[:, None].float() * freqs[None]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if
dim
%
2
:
# if dim % 2:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return
embedding
# return embedding
def
zero_module
(
module
):
"""
Zero out the parameters of a module and return it.
"""
for
p
in
module
.
parameters
():
p
.
detach
().
zero_
()
return
module
## go
## go
...
@@ -1026,7 +1018,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -1026,7 +1018,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
hs
=
[]
hs
=
[]
if
not
torch
.
is_tensor
(
timesteps
):
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
x
.
device
)
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
x
.
device
)
t_emb
=
timestep_embedding
(
timesteps
,
self
.
model_channels
)
t_emb
=
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
)
emb
=
self
.
time_embed
(
t_emb
)
emb
=
self
.
time_embed
(
t_emb
)
if
self
.
num_classes
is
not
None
:
if
self
.
num_classes
is
not
None
:
...
@@ -1240,7 +1232,7 @@ class EncoderUNetModel(nn.Module):
...
@@ -1240,7 +1232,7 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps.
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
:return: an [N x K] Tensor of outputs.
"""
"""
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
get_
timestep_embedding
(
timesteps
,
self
.
model_channels
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
))
results
=
[]
results
=
[]
h
=
x
.
type
(
self
.
dtype
)
h
=
x
.
type
(
self
.
dtype
)
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
02a76c2c
...
@@ -26,6 +26,7 @@ import torch.nn.functional as F
...
@@ -26,6 +26,7 @@ import torch.nn.functional as F
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
def
upfirdn2d
(
input
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
def
upfirdn2d
(
input
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
...
@@ -381,21 +382,21 @@ def get_act(nonlinearity):
...
@@ -381,21 +382,21 @@ def get_act(nonlinearity):
raise
NotImplementedError
(
"activation function does not exist!"
)
raise
NotImplementedError
(
"activation function does not exist!"
)
def
get_timestep_embedding
(
timesteps
,
embedding_dim
,
max_positions
=
10000
):
#
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert
len
(
timesteps
.
shape
)
==
1
# and timesteps.dtype == tf.int32
#
assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
half_dim
=
embedding_dim
//
2
#
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
# magic number 10000 is from transformers
emb
=
math
.
log
(
max_positions
)
/
(
half_dim
-
1
)
#
emb = math.log(max_positions) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
,
device
=
timesteps
.
device
)
*
-
emb
)
#
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
#
emb = timesteps.float()[:, None] * emb[None, :]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
#
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if
embedding_dim
%
2
==
1
:
# zero pad
#
if embedding_dim % 2 == 1: # zero pad
emb
=
F
.
pad
(
emb
,
(
0
,
1
),
mode
=
"constant"
)
#
emb = F.pad(emb, (0, 1), mode="constant")
assert
emb
.
shape
==
(
timesteps
.
shape
[
0
],
embedding_dim
)
#
assert emb.shape == (timesteps.shape[0], embedding_dim)
return
emb
#
return emb
def
default_init
(
scale
=
1.0
):
def
default_init
(
scale
=
1.0
):
...
...
tests/test_layers_utils.py
View file @
02a76c2c
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment