Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
61d96c3a
Unverified
Commit
61d96c3a
authored
Aug 29, 2024
by
YiYi Xu
Committed by
GitHub
Aug 30, 2024
Browse files
refactor rotary embedding 3: so it is not on cpu (#9307)
change get_1d_rotary to accept pos as torch tensors
parent
4f495b06
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
4 deletions
+7
-4
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+7
-4
No files found.
src/diffusers/models/embeddings.py
View file @
61d96c3a
...
@@ -545,11 +545,14 @@ def get_1d_rotary_pos_embed(
...
@@ -545,11 +545,14 @@ def get_1d_rotary_pos_embed(
assert
dim
%
2
==
0
assert
dim
%
2
==
0
if
isinstance
(
pos
,
int
):
if
isinstance
(
pos
,
int
):
pos
=
np
.
arange
(
pos
)
pos
=
torch
.
arange
(
pos
)
if
isinstance
(
pos
,
np
.
ndarray
):
pos
=
torch
.
from_numpy
(
pos
)
# type: ignore # [S]
theta
=
theta
*
ntk_factor
theta
=
theta
*
ntk_factor
freqs
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim
,
2
,
dtype
=
freqs_dtype
)[:
(
dim
//
2
)]
/
dim
))
/
linear_factor
# [D/2]
freqs
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim
,
2
,
dtype
=
freqs_dtype
)[:
(
dim
//
2
)]
/
dim
))
/
linear_factor
# [D/2]
t
=
torch
.
from_numpy
(
pos
).
to
(
freqs
.
device
)
# type: ignore # [S]
freqs
=
freqs
.
to
(
pos
.
device
)
freqs
=
torch
.
outer
(
t
,
freqs
)
# type: ignore # [S, D/2]
freqs
=
torch
.
outer
(
pos
,
freqs
)
# type: ignore # [S, D/2]
if
use_real
and
repeat_interleave_real
:
if
use_real
and
repeat_interleave_real
:
# flux, hunyuan-dit, cogvideox
# flux, hunyuan-dit, cogvideox
freqs_cos
=
freqs
.
cos
().
repeat_interleave
(
2
,
dim
=
1
).
float
()
# [S, D]
freqs_cos
=
freqs
.
cos
().
repeat_interleave
(
2
,
dim
=
1
).
float
()
# [S, D]
...
@@ -626,7 +629,7 @@ class FluxPosEmbed(nn.Module):
...
@@ -626,7 +629,7 @@ class FluxPosEmbed(nn.Module):
n_axes
=
ids
.
shape
[
-
1
]
n_axes
=
ids
.
shape
[
-
1
]
cos_out
=
[]
cos_out
=
[]
sin_out
=
[]
sin_out
=
[]
pos
=
ids
.
squeeze
().
float
()
.
cpu
().
numpy
()
pos
=
ids
.
squeeze
().
float
()
is_mps
=
ids
.
device
.
type
==
"mps"
is_mps
=
ids
.
device
.
type
==
"mps"
freqs_dtype
=
torch
.
float32
if
is_mps
else
torch
.
float64
freqs_dtype
=
torch
.
float32
if
is_mps
else
torch
.
float64
for
i
in
range
(
n_axes
):
for
i
in
range
(
n_axes
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment