Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
1ca0a755
Unverified
Commit
1ca0a755
authored
Aug 25, 2024
by
YiYi Xu
Committed by
GitHub
Aug 25, 2024
Browse files
refactor 3d rope for cogvideox (#9269)
* refactor 3d rope * repeat -> expand
parent
c1e6a32a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
52 deletions
+35
-52
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+35
-51
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+0
-1
No files found.
src/diffusers/models/embeddings.py
View file @
1ca0a755
...
...
@@ -391,15 +391,16 @@ def get_3d_rotary_pos_embed(
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
if
use_real
is
not
True
:
raise
ValueError
(
" `use_real = False` is not currently supported for get_3d_rotary_pos_embed"
)
start
,
stop
=
crops_coords
grid_h
=
np
.
linspace
(
start
[
0
],
stop
[
0
],
grid_size
[
0
],
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_w
=
np
.
linspace
(
start
[
1
],
stop
[
1
],
grid_size
[
1
],
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_size_h
,
grid_size_w
=
grid_size
grid_h
=
np
.
linspace
(
start
[
0
],
stop
[
0
],
grid_size_h
,
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_w
=
np
.
linspace
(
start
[
1
],
stop
[
1
],
grid_size_w
,
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_t
=
np
.
linspace
(
0
,
temporal_size
,
temporal_size
,
endpoint
=
False
,
dtype
=
np
.
float32
)
# Compute dimensions for each axis
...
...
@@ -408,54 +409,37 @@ def get_3d_rotary_pos_embed(
dim_w
=
embed_dim
//
8
*
3
# Temporal frequencies
freqs_t
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim_t
,
2
).
float
()
/
dim_t
))
grid_t
=
torch
.
from_numpy
(
grid_t
).
float
()
freqs_t
=
torch
.
einsum
(
"n , f -> n f"
,
grid_t
,
freqs_t
)
freqs_t
=
freqs_t
.
repeat_interleave
(
2
,
dim
=-
1
)
freqs_t
=
get_1d_rotary_pos_embed
(
dim_t
,
grid_t
,
use_real
=
True
)
# Spatial frequencies for height and width
freqs_h
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim_h
,
2
).
float
()
/
dim_h
))
freqs_w
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim_w
,
2
).
float
()
/
dim_w
))
grid_h
=
torch
.
from_numpy
(
grid_h
).
float
()
grid_w
=
torch
.
from_numpy
(
grid_w
).
float
()
freqs_h
=
torch
.
einsum
(
"n , f -> n f"
,
grid_h
,
freqs_h
)
freqs_w
=
torch
.
einsum
(
"n , f -> n f"
,
grid_w
,
freqs_w
)
freqs_h
=
freqs_h
.
repeat_interleave
(
2
,
dim
=-
1
)
freqs_w
=
freqs_w
.
repeat_interleave
(
2
,
dim
=-
1
)
# Broadcast and concatenate tensors along specified dimension
def
broadcast
(
tensors
,
dim
=-
1
):
num_tensors
=
len
(
tensors
)
shape_lens
=
{
len
(
t
.
shape
)
for
t
in
tensors
}
assert
len
(
shape_lens
)
==
1
,
"tensors must all have the same number of dimensions"
shape_len
=
list
(
shape_lens
)[
0
]
dim
=
(
dim
+
shape_len
)
if
dim
<
0
else
dim
dims
=
list
(
zip
(
*
(
list
(
t
.
shape
)
for
t
in
tensors
)))
expandable_dims
=
[(
i
,
val
)
for
i
,
val
in
enumerate
(
dims
)
if
i
!=
dim
]
assert
all
(
[
*
(
len
(
set
(
t
[
1
]))
<=
2
for
t
in
expandable_dims
)]
),
"invalid dimensions for broadcastable concatenation"
max_dims
=
[(
t
[
0
],
max
(
t
[
1
]))
for
t
in
expandable_dims
]
expanded_dims
=
[(
t
[
0
],
(
t
[
1
],)
*
num_tensors
)
for
t
in
max_dims
]
expanded_dims
.
insert
(
dim
,
(
dim
,
dims
[
dim
]))
expandable_shapes
=
list
(
zip
(
*
(
t
[
1
]
for
t
in
expanded_dims
)))
tensors
=
[
t
[
0
].
expand
(
*
t
[
1
])
for
t
in
zip
(
tensors
,
expandable_shapes
)]
return
torch
.
cat
(
tensors
,
dim
=
dim
)
freqs
=
broadcast
((
freqs_t
[:,
None
,
None
,
:],
freqs_h
[
None
,
:,
None
,
:],
freqs_w
[
None
,
None
,
:,
:]),
dim
=-
1
)
t
,
h
,
w
,
d
=
freqs
.
shape
freqs
=
freqs
.
view
(
t
*
h
*
w
,
d
)
# Generate sine and cosine components
sin
=
freqs
.
sin
()
cos
=
freqs
.
cos
()
if
use_real
:
freqs_h
=
get_1d_rotary_pos_embed
(
dim_h
,
grid_h
,
use_real
=
True
)
freqs_w
=
get_1d_rotary_pos_embed
(
dim_w
,
grid_w
,
use_real
=
True
)
# BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
def
combine_time_height_width
(
freqs_t
,
freqs_h
,
freqs_w
):
freqs_t
=
freqs_t
[:,
None
,
None
,
:].
expand
(
-
1
,
grid_size_h
,
grid_size_w
,
-
1
)
# temporal_size, grid_size_h, grid_size_w, dim_t
freqs_h
=
freqs_h
[
None
,
:,
None
,
:].
expand
(
temporal_size
,
-
1
,
grid_size_w
,
-
1
)
# temporal_size, grid_size_h, grid_size_2, dim_h
freqs_w
=
freqs_w
[
None
,
None
,
:,
:].
expand
(
temporal_size
,
grid_size_h
,
-
1
,
-
1
)
# temporal_size, grid_size_h, grid_size_2, dim_w
freqs
=
torch
.
cat
(
[
freqs_t
,
freqs_h
,
freqs_w
],
dim
=-
1
)
# temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
freqs
=
freqs
.
view
(
temporal_size
*
grid_size_h
*
grid_size_w
,
-
1
)
# (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
return
freqs
t_cos
,
t_sin
=
freqs_t
# both t_cos and t_sin has shape: temporal_size, dim_t
h_cos
,
h_sin
=
freqs_h
# both h_cos and h_sin has shape: grid_size_h, dim_h
w_cos
,
w_sin
=
freqs_w
# both w_cos and w_sin has shape: grid_size_w, dim_w
cos
=
combine_time_height_width
(
t_cos
,
h_cos
,
w_cos
)
sin
=
combine_time_height_width
(
t_sin
,
h_sin
,
w_sin
)
return
cos
,
sin
else
:
freqs_cis
=
torch
.
polar
(
torch
.
ones_like
(
freqs
),
freqs
)
return
freqs_cis
def
get_2d_rotary_pos_embed
(
embed_dim
,
crops_coords
,
grid_size
,
use_real
=
True
):
...
...
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
View file @
1ca0a755
...
...
@@ -463,7 +463,6 @@ class CogVideoXPipeline(DiffusionPipeline):
crops_coords
=
grid_crops_coords
,
grid_size
=
(
grid_height
,
grid_width
),
temporal_size
=
num_frames
,
use_real
=
True
,
)
freqs_cos
=
freqs_cos
.
to
(
device
=
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment