Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
1ca0a755
Unverified
Commit
1ca0a755
authored
Aug 25, 2024
by
YiYi Xu
Committed by
GitHub
Aug 25, 2024
Browse files
refactor 3d rope for cogvideox (#9269)
* refactor 3d rope * repeat -> expand
parent
c1e6a32a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
52 deletions
+35
-52
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+35
-51
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+0
-1
No files found.
src/diffusers/models/embeddings.py
View file @
1ca0a755
...
...
@@ -391,15 +391,16 @@ def get_3d_rotary_pos_embed(
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
if
use_real
is
not
True
:
raise
ValueError
(
" `use_real = False` is not currently supported for get_3d_rotary_pos_embed"
)
start
,
stop
=
crops_coords
grid_h
=
np
.
linspace
(
start
[
0
],
stop
[
0
],
grid_size
[
0
],
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_w
=
np
.
linspace
(
start
[
1
],
stop
[
1
],
grid_size
[
1
],
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_size_h
,
grid_size_w
=
grid_size
grid_h
=
np
.
linspace
(
start
[
0
],
stop
[
0
],
grid_size_h
,
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_w
=
np
.
linspace
(
start
[
1
],
stop
[
1
],
grid_size_w
,
endpoint
=
False
,
dtype
=
np
.
float32
)
grid_t
=
np
.
linspace
(
0
,
temporal_size
,
temporal_size
,
endpoint
=
False
,
dtype
=
np
.
float32
)
# Compute dimensions for each axis
...
...
@@ -408,54 +409,37 @@ def get_3d_rotary_pos_embed(
dim_w
=
embed_dim
//
8
*
3
# Temporal frequencies
freqs_t
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim_t
,
2
).
float
()
/
dim_t
))
grid_t
=
torch
.
from_numpy
(
grid_t
).
float
()
freqs_t
=
torch
.
einsum
(
"n , f -> n f"
,
grid_t
,
freqs_t
)
freqs_t
=
freqs_t
.
repeat_interleave
(
2
,
dim
=-
1
)
freqs_t
=
get_1d_rotary_pos_embed
(
dim_t
,
grid_t
,
use_real
=
True
)
# Spatial frequencies for height and width
freqs_h
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim_h
,
2
).
float
()
/
dim_h
))
freqs_w
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim_w
,
2
).
float
()
/
dim_w
))
grid_h
=
torch
.
from_numpy
(
grid_h
).
float
()
grid_w
=
torch
.
from_numpy
(
grid_w
).
float
()
freqs_h
=
torch
.
einsum
(
"n , f -> n f"
,
grid_h
,
freqs_h
)
freqs_w
=
torch
.
einsum
(
"n , f -> n f"
,
grid_w
,
freqs_w
)
freqs_h
=
freqs_h
.
repeat_interleave
(
2
,
dim
=-
1
)
freqs_w
=
freqs_w
.
repeat_interleave
(
2
,
dim
=-
1
)
# Broadcast and concatenate tensors along specified dimension
def
broadcast
(
tensors
,
dim
=-
1
):
num_tensors
=
len
(
tensors
)
shape_lens
=
{
len
(
t
.
shape
)
for
t
in
tensors
}
assert
len
(
shape_lens
)
==
1
,
"tensors must all have the same number of dimensions"
shape_len
=
list
(
shape_lens
)[
0
]
dim
=
(
dim
+
shape_len
)
if
dim
<
0
else
dim
dims
=
list
(
zip
(
*
(
list
(
t
.
shape
)
for
t
in
tensors
)))
expandable_dims
=
[(
i
,
val
)
for
i
,
val
in
enumerate
(
dims
)
if
i
!=
dim
]
assert
all
(
[
*
(
len
(
set
(
t
[
1
]))
<=
2
for
t
in
expandable_dims
)]
),
"invalid dimensions for broadcastable concatenation"
max_dims
=
[(
t
[
0
],
max
(
t
[
1
]))
for
t
in
expandable_dims
]
expanded_dims
=
[(
t
[
0
],
(
t
[
1
],)
*
num_tensors
)
for
t
in
max_dims
]
expanded_dims
.
insert
(
dim
,
(
dim
,
dims
[
dim
]))
expandable_shapes
=
list
(
zip
(
*
(
t
[
1
]
for
t
in
expanded_dims
)))
tensors
=
[
t
[
0
].
expand
(
*
t
[
1
])
for
t
in
zip
(
tensors
,
expandable_shapes
)]
return
torch
.
cat
(
tensors
,
dim
=
dim
)
freqs
=
broadcast
((
freqs_t
[:,
None
,
None
,
:],
freqs_h
[
None
,
:,
None
,
:],
freqs_w
[
None
,
None
,
:,
:]),
dim
=-
1
)
t
,
h
,
w
,
d
=
freqs
.
shape
freqs
=
freqs
.
view
(
t
*
h
*
w
,
d
)
# Generate sine and cosine components
sin
=
freqs
.
sin
()
cos
=
freqs
.
cos
()
if
use_real
:
return
cos
,
sin
else
:
freqs_cis
=
torch
.
polar
(
torch
.
ones_like
(
freqs
),
freqs
)
return
freqs_cis
freqs_h
=
get_1d_rotary_pos_embed
(
dim_h
,
grid_h
,
use_real
=
True
)
freqs_w
=
get_1d_rotary_pos_embed
(
dim_w
,
grid_w
,
use_real
=
True
)
# BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
def
combine_time_height_width
(
freqs_t
,
freqs_h
,
freqs_w
):
freqs_t
=
freqs_t
[:,
None
,
None
,
:].
expand
(
-
1
,
grid_size_h
,
grid_size_w
,
-
1
)
# temporal_size, grid_size_h, grid_size_w, dim_t
freqs_h
=
freqs_h
[
None
,
:,
None
,
:].
expand
(
temporal_size
,
-
1
,
grid_size_w
,
-
1
)
# temporal_size, grid_size_h, grid_size_2, dim_h
freqs_w
=
freqs_w
[
None
,
None
,
:,
:].
expand
(
temporal_size
,
grid_size_h
,
-
1
,
-
1
)
# temporal_size, grid_size_h, grid_size_2, dim_w
freqs
=
torch
.
cat
(
[
freqs_t
,
freqs_h
,
freqs_w
],
dim
=-
1
)
# temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
freqs
=
freqs
.
view
(
temporal_size
*
grid_size_h
*
grid_size_w
,
-
1
)
# (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
return
freqs
t_cos
,
t_sin
=
freqs_t
# both t_cos and t_sin has shape: temporal_size, dim_t
h_cos
,
h_sin
=
freqs_h
# both h_cos and h_sin has shape: grid_size_h, dim_h
w_cos
,
w_sin
=
freqs_w
# both w_cos and w_sin has shape: grid_size_w, dim_w
cos
=
combine_time_height_width
(
t_cos
,
h_cos
,
w_cos
)
sin
=
combine_time_height_width
(
t_sin
,
h_sin
,
w_sin
)
return
cos
,
sin
def
get_2d_rotary_pos_embed
(
embed_dim
,
crops_coords
,
grid_size
,
use_real
=
True
):
...
...
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
View file @
1ca0a755
...
...
@@ -463,7 +463,6 @@ class CogVideoXPipeline(DiffusionPipeline):
crops_coords
=
grid_crops_coords
,
grid_size
=
(
grid_height
,
grid_width
),
temporal_size
=
num_frames
,
use_real
=
True
,
)
freqs_cos
=
freqs_cos
.
to
(
device
=
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment