Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
168e5b7f
Commit
168e5b7f
authored
Jun 27, 2022
by
Patrick von Platen
Browse files
add embeddings
parent
3562a3e6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
80 additions
and
0 deletions
+80
-0
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+80
-0
No files found.
src/diffusers/models/embeddings.py
View file @
168e5b7f
...
@@ -54,3 +54,83 @@ def timestep_embedding(timesteps, dim, max_period=10000):
...
@@ -54,3 +54,83 @@ def timestep_embedding(timesteps, dim, max_period=10000):
if
dim
%
2
:
if
dim
%
2
:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
return
embedding
return
embedding
# unet_grad_tts.py
class
SinusoidalPosEmb
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
(
SinusoidalPosEmb
,
self
).
__init__
()
self
.
dim
=
dim
def
forward
(
self
,
x
,
scale
=
1000
):
device
=
x
.
device
half_dim
=
self
.
dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
device
=
device
).
float
()
*
-
emb
)
emb
=
scale
*
x
.
unsqueeze
(
1
)
*
emb
.
unsqueeze
(
0
)
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
return
emb
# unet_ldm.py
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half
=
dim
//
2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
device
=
timesteps
.
device
)
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
if
dim
%
2
:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
return
embedding
# unet_rl.py
class
SinusoidalPosEmb
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
dim
=
dim
def
forward
(
self
,
x
):
device
=
x
.
device
half_dim
=
self
.
dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
device
=
device
)
*
-
emb
)
emb
=
x
[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
return
emb
# unet_sde_score_estimation.py
def
get_timestep_embedding
(
timesteps
,
embedding_dim
,
max_positions
=
10000
):
assert
len
(
timesteps
.
shape
)
==
1
# and timesteps.dtype == tf.int32
half_dim
=
embedding_dim
//
2
# magic number 10000 is from transformers
emb
=
math
.
log
(
max_positions
)
/
(
half_dim
-
1
)
# emb = math.log(2.) / (half_dim - 1)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
,
device
=
timesteps
.
device
)
*
-
emb
)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
F
.
pad
(
emb
,
(
0
,
1
),
mode
=
"constant"
)
assert
emb
.
shape
==
(
timesteps
.
shape
[
0
],
embedding_dim
)
return
emb
# unet_sde_score_estimation.py
class
GaussianFourierProjection
(
nn
.
Module
):
"""Gaussian Fourier embeddings for noise levels."""
def
__init__
(
self
,
embedding_size
=
256
,
scale
=
1.0
):
super
().
__init__
()
self
.
W
=
nn
.
Parameter
(
torch
.
randn
(
embedding_size
)
*
scale
,
requires_grad
=
False
)
def
forward
(
self
,
x
):
x_proj
=
x
[:,
None
]
*
self
.
W
[
None
,
:]
*
2
*
np
.
pi
return
torch
.
cat
([
torch
.
sin
(
x_proj
),
torch
.
cos
(
x_proj
)],
dim
=-
1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment