Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
1ca0a755
Unverified
Commit
1ca0a755
authored
Aug 25, 2024
by
YiYi Xu
Committed by
GitHub
Aug 25, 2024
Browse files
refactor 3d rope for cogvideox (#9269)
* refactor 3d rope * repeat -> expand
parent
c1e6a32a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
52 deletions
+35
-52
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+35
-51
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+0
-1
No files found.
src/diffusers/models/embeddings.py
View file @
1ca0a755
...
@@ -391,15 +391,16 @@ def get_3d_rotary_pos_embed(
...
@@ -391,15 +391,16 @@ def get_3d_rotary_pos_embed(
The size of the temporal dimension.
The size of the temporal dimension.
theta (`float`):
theta (`float`):
Scaling factor for frequency computation.
Scaling factor for frequency computation.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
"""
if
use_real
is
not
True
:
raise
ValueError
(
" `use_real = False` is not currently supported for get_3d_rotary_pos_embed"
)
start
,
stop
=
crops_coords
start
,
stop
=
crops_coords
grid_h
=
np
.
linspace
(
start
[
0
],
stop
[
0
],
grid_size
[
0
],
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_size_h
,
grid_size_w
=
grid_size
grid_w
=
np
.
linspace
(
start
[
1
],
stop
[
1
],
grid_size
[
1
],
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_h
=
np
.
linspace
(
start
[
0
],
stop
[
0
],
grid_size_h
,
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_w
=
np
.
linspace
(
start
[
1
],
stop
[
1
],
grid_size_w
,
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_t
=
np
.
linspace
(
0
,
temporal_size
,
temporal_size
,
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_t
=
np
.
linspace
(
0
,
temporal_size
,
temporal_size
,
endpoint
=
False
,
dtype
=
np
.
float32
)
# Compute dimensions for each axis
# Compute dimensions for each axis
...
@@ -408,54 +409,37 @@ def get_3d_rotary_pos_embed(
...
@@ -408,54 +409,37 @@ def get_3d_rotary_pos_embed(
dim_w
=
embed_dim
//
8
*
3
dim_w
=
embed_dim
//
8
*
3
# Temporal frequencies
# Temporal frequencies
freqs_t
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim_t
,
2
).
float
()
/
dim_t
))
freqs_t
=
get_1d_rotary_pos_embed
(
dim_t
,
grid_t
,
use_real
=
True
)
grid_t
=
torch
.
from_numpy
(
grid_t
).
float
()
freqs_t
=
torch
.
einsum
(
"n , f -> n f"
,
grid_t
,
freqs_t
)
freqs_t
=
freqs_t
.
repeat_interleave
(
2
,
dim
=-
1
)
# Spatial frequencies for height and width
# Spatial frequencies for height and width
freqs_h
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim_h
,
2
).
float
()
/
dim_h
))
freqs_h
=
get_1d_rotary_pos_embed
(
dim_h
,
grid_h
,
use_real
=
True
)
freqs_w
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim_w
,
2
).
float
()
/
dim_w
))
freqs_w
=
get_1d_rotary_pos_embed
(
dim_w
,
grid_w
,
use_real
=
True
)
grid_h
=
torch
.
from_numpy
(
grid_h
).
float
()
grid_w
=
torch
.
from_numpy
(
grid_w
).
float
()
# Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
freqs_h
=
torch
.
einsum
(
"n , f -> n f"
,
grid_h
,
freqs_h
)
def
combine_time_height_width
(
freqs_t
,
freqs_h
,
freqs_w
):
freqs_w
=
torch
.
einsum
(
"n , f -> n f"
,
grid_w
,
freqs_w
)
freqs_t
=
freqs_t
[:,
None
,
None
,
:].
expand
(
freqs_h
=
freqs_h
.
repeat_interleave
(
2
,
dim
=-
1
)
-
1
,
grid_size_h
,
grid_size_w
,
-
1
freqs_w
=
freqs_w
.
repeat_interleave
(
2
,
dim
=-
1
)
)
# temporal_size, grid_size_h, grid_size_w, dim_t
freqs_h
=
freqs_h
[
None
,
:,
None
,
:].
expand
(
# Broadcast and concatenate tensors along specified dimension
temporal_size
,
-
1
,
grid_size_w
,
-
1
def
broadcast
(
tensors
,
dim
=-
1
):
)
# temporal_size, grid_size_h, grid_size_w, dim_h
num_tensors
=
len
(
tensors
)
freqs_w
=
freqs_w
[
None
,
None
,
:,
:].
expand
(
shape_lens
=
{
len
(
t
.
shape
)
for
t
in
tensors
}
temporal_size
,
grid_size_h
,
-
1
,
-
1
assert
len
(
shape_lens
)
==
1
,
"tensors must all have the same number of dimensions"
)
# temporal_size, grid_size_h, grid_size_w, dim_w
shape_len
=
list
(
shape_lens
)[
0
]
dim
=
(
dim
+
shape_len
)
if
dim
<
0
else
dim
freqs
=
torch
.
cat
(
dims
=
list
(
zip
(
*
(
list
(
t
.
shape
)
for
t
in
tensors
)))
[
freqs_t
,
freqs_h
,
freqs_w
],
dim
=-
1
expandable_dims
=
[(
i
,
val
)
for
i
,
val
in
enumerate
(
dims
)
if
i
!=
dim
]
)
# temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
assert
all
(
freqs
=
freqs
.
view
(
[
*
(
len
(
set
(
t
[
1
]))
<=
2
for
t
in
expandable_dims
)]
temporal_size
*
grid_size_h
*
grid_size_w
,
-
1
),
"invalid dimensions for broadcastable concatenation"
)
# (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
max_dims
=
[(
t
[
0
],
max
(
t
[
1
]))
for
t
in
expandable_dims
]
return
freqs
expanded_dims
=
[(
t
[
0
],
(
t
[
1
],)
*
num_tensors
)
for
t
in
max_dims
]
expanded_dims
.
insert
(
dim
,
(
dim
,
dims
[
dim
]))
t_cos
,
t_sin
=
freqs_t
# both t_cos and t_sin have shape: temporal_size, dim_t
expandable_shapes
=
list
(
zip
(
*
(
t
[
1
]
for
t
in
expanded_dims
)))
h_cos
,
h_sin
=
freqs_h
# both h_cos and h_sin have shape: grid_size_h, dim_h
tensors
=
[
t
[
0
].
expand
(
*
t
[
1
])
for
t
in
zip
(
tensors
,
expandable_shapes
)]
w_cos
,
w_sin
=
freqs_w
# both w_cos and w_sin have shape: grid_size_w, dim_w
return
torch
.
cat
(
tensors
,
dim
=
dim
)
cos
=
combine_time_height_width
(
t_cos
,
h_cos
,
w_cos
)
sin
=
combine_time_height_width
(
t_sin
,
h_sin
,
w_sin
)
freqs
=
broadcast
((
freqs_t
[:,
None
,
None
,
:],
freqs_h
[
None
,
:,
None
,
:],
freqs_w
[
None
,
None
,
:,
:]),
dim
=-
1
)
return
cos
,
sin
t
,
h
,
w
,
d
=
freqs
.
shape
freqs
=
freqs
.
view
(
t
*
h
*
w
,
d
)
# Generate sine and cosine components
sin
=
freqs
.
sin
()
cos
=
freqs
.
cos
()
if
use_real
:
return
cos
,
sin
else
:
freqs_cis
=
torch
.
polar
(
torch
.
ones_like
(
freqs
),
freqs
)
return
freqs_cis
def
get_2d_rotary_pos_embed
(
embed_dim
,
crops_coords
,
grid_size
,
use_real
=
True
):
def
get_2d_rotary_pos_embed
(
embed_dim
,
crops_coords
,
grid_size
,
use_real
=
True
):
...
...
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
View file @
1ca0a755
...
@@ -463,7 +463,6 @@ class CogVideoXPipeline(DiffusionPipeline):
...
@@ -463,7 +463,6 @@ class CogVideoXPipeline(DiffusionPipeline):
crops_coords
=
grid_crops_coords
,
crops_coords
=
grid_crops_coords
,
grid_size
=
(
grid_height
,
grid_width
),
grid_size
=
(
grid_height
,
grid_width
),
temporal_size
=
num_frames
,
temporal_size
=
num_frames
,
use_real
=
True
,
)
)
freqs_cos
=
freqs_cos
.
to
(
device
=
device
)
freqs_cos
=
freqs_cos
.
to
(
device
=
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment