Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
4f495b06
Unverified
Commit
4f495b06
authored
Aug 28, 2024
by
YiYi Xu
Committed by
GitHub
Aug 28, 2024
Browse files
rotary embedding refactor 2: update comments, fix dtype for use_real=False (#9312)
fix notes and dtype
parent
40c13fe5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
4 deletions
+8
-4
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+8
-4
No files found.
src/diffusers/models/embeddings.py
View file @
4f495b06
...
...
@@ -514,7 +514,7 @@ def get_1d_rotary_pos_embed(
linear_factor
=
1.0
,
ntk_factor
=
1.0
,
repeat_interleave_real
=
True
,
freqs_dtype
=
torch
.
float32
,
# torch.float32
(hunyuan, stable audio)
, torch.float64 (flux)
freqs_dtype
=
torch
.
float32
,
#
torch.float32, torch.float64 (flux)
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
...
...
@@ -551,15 +551,18 @@ def get_1d_rotary_pos_embed(
t
=
torch
.
from_numpy
(
pos
).
to
(
freqs
.
device
)
# type: ignore # [S]
freqs
=
torch
.
outer
(
t
,
freqs
)
# type: ignore # [S, D/2]
if
use_real
and
repeat_interleave_real
:
# flux, hunyuan-dit, cogvideox
freqs_cos
=
freqs
.
cos
().
repeat_interleave
(
2
,
dim
=
1
).
float
()
# [S, D]
freqs_sin
=
freqs
.
sin
().
repeat_interleave
(
2
,
dim
=
1
).
float
()
# [S, D]
return
freqs_cos
,
freqs_sin
elif
use_real
:
# stable audio
freqs_cos
=
torch
.
cat
([
freqs
.
cos
(),
freqs
.
cos
()],
dim
=-
1
).
float
()
# [S, D]
freqs_sin
=
torch
.
cat
([
freqs
.
sin
(),
freqs
.
sin
()],
dim
=-
1
).
float
()
# [S, D]
return
freqs_cos
,
freqs_sin
else
:
freqs_cis
=
torch
.
polar
(
torch
.
ones_like
(
freqs
),
freqs
).
float
()
# complex64 # [S, D/2]
# lumina
freqs_cis
=
torch
.
polar
(
torch
.
ones_like
(
freqs
),
freqs
)
# complex64 # [S, D/2]
return
freqs_cis
...
...
@@ -590,11 +593,11 @@ def apply_rotary_emb(
cos
,
sin
=
cos
.
to
(
x
.
device
),
sin
.
to
(
x
.
device
)
if
use_real_unbind_dim
==
-
1
:
# Use for
example in Lumina
# Use
d
for
flux, cogvideox, hunyuan-dit
x_real
,
x_imag
=
x
.
reshape
(
*
x
.
shape
[:
-
1
],
-
1
,
2
).
unbind
(
-
1
)
# [B, S, H, D//2]
x_rotated
=
torch
.
stack
([
-
x_imag
,
x_real
],
dim
=-
1
).
flatten
(
3
)
elif
use_real_unbind_dim
==
-
2
:
# Use for
example in
Stable Audio
# Use
d
for Stable Audio
x_real
,
x_imag
=
x
.
reshape
(
*
x
.
shape
[:
-
1
],
2
,
-
1
).
unbind
(
-
2
)
# [B, S, H, D//2]
x_rotated
=
torch
.
cat
([
-
x_imag
,
x_real
],
dim
=-
1
)
else
:
...
...
@@ -604,6 +607,7 @@ def apply_rotary_emb(
return
out
else
:
# used for lumina
x_rotated
=
torch
.
view_as_complex
(
x
.
float
().
reshape
(
*
x
.
shape
[:
-
1
],
-
1
,
2
))
freqs_cis
=
freqs_cis
.
unsqueeze
(
2
)
x_out
=
torch
.
view_as_real
(
x_rotated
*
freqs_cis
).
flatten
(
3
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment