renzhc / diffusers_dcu · Commits · 4c4b323c

Unverified commit 4c4b323c, authored Dec 10, 2024 by hlky, committed by GitHub on Dec 10, 2024.
Parent: 22d3a826

Use `torch` in `get_3d_rotary_pos_embed`/`_allegro` (#10161)

Showing 8 changed files with 41 additions and 32 deletions (+41 -32).
Files changed (8):

  examples/cogvideo/train_cogvideox_image_to_video_lora.py             +1  -2
  examples/cogvideo/train_cogvideox_lora.py                            +1  -2
  src/diffusers/models/embeddings.py                                   +27 -13
  src/diffusers/pipelines/allegro/pipeline_allegro.py                  +4  -7
  src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py               +2  -2
  src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py   +2  -2
  src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py   +2  -2
  src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py   +2  -2
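
The theme of the commit is the same across all eight files: the RoPE grids are created with torch directly on the caller's device instead of being built with numpy on the CPU and copied over afterwards. A minimal sketch of the before/after pattern (the sizes and the device pick here are illustrative, not taken from the diff):

    import numpy as np
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    temporal_size = 13

    # Before: grid built on the host, then converted and transferred.
    grid_t_old = torch.from_numpy(
        np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
    ).to(device=device)

    # After: grid built once, already on the target device.
    grid_t_new = torch.arange(temporal_size, device=device, dtype=torch.float32)

    assert torch.equal(grid_t_old, grid_t_new)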
examples/cogvideo/train_cogvideox_image_to_video_lora.py

@@ -872,10 +872,9 @@ def prepare_rotary_positional_embeddings(
         crops_coords=grid_crops_coords,
         grid_size=(grid_height, grid_width),
         temporal_size=num_frames,
+        device=device,
     )
 
-    freqs_cos = freqs_cos.to(device=device)
-    freqs_sin = freqs_sin.to(device=device)
     return freqs_cos, freqs_sin
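
The two deleted `.to(device=device)` lines were a post-hoc transfer of the full cos/sin tables; with `device` forwarded into `get_3d_rotary_pos_embed`, the tables come back on the right device already. A hedged sketch of that check against the public function (embed_dim, crop region, grid sizes and frame count are placeholder values, not the script's):

    import torch
    from diffusers.models.embeddings import get_3d_rotary_pos_embed

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
        embed_dim=64,                     # placeholder attention head dim
        crops_coords=((0, 0), (30, 45)),  # placeholder crop region
        grid_size=(30, 45),               # placeholder latent grid
        temporal_size=13,                 # placeholder frame count
        device=device,
    )
    assert freqs_cos.device.type == device.type
    assert freqs_sin.device.type == device.type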
examples/cogvideo/train_cogvideox_lora.py

@@ -894,10 +894,9 @@ def prepare_rotary_positional_embeddings(
         crops_coords=grid_crops_coords,
         grid_size=(grid_height, grid_width),
         temporal_size=num_frames,
+        device=device,
     )
 
-    freqs_cos = freqs_cos.to(device=device)
-    freqs_sin = freqs_sin.to(device=device)
     return freqs_cos, freqs_sin
src/diffusers/models/embeddings.py

@@ -594,6 +594,7 @@ def get_3d_rotary_pos_embed(
     use_real: bool = True,
     grid_type: str = "linspace",
     max_size: Optional[Tuple[int, int]] = None,
+    device: Optional[torch.device] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     RoPE for video tokens with 3D structure.

@@ -621,16 +622,22 @@ def get_3d_rotary_pos_embed(
     if grid_type == "linspace":
         start, stop = crops_coords
         grid_size_h, grid_size_w = grid_size
-        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
-        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
-        grid_t = np.arange(temporal_size, dtype=np.float32)
-        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+        grid_h = torch.linspace(
+            start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+        )
+        grid_w = torch.linspace(
+            start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+        )
+        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
+        grid_t = torch.linspace(
+            0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+        )
     elif grid_type == "slice":
         max_h, max_w = max_size
         grid_size_h, grid_size_w = grid_size
-        grid_h = np.arange(max_h, dtype=np.float32)
-        grid_w = np.arange(max_w, dtype=np.float32)
-        grid_t = np.arange(temporal_size, dtype=np.float32)
+        grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
+        grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
+        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
     else:
         raise ValueError("Invalid value passed for `grid_type`.")

@@ -640,10 +647,10 @@ def get_3d_rotary_pos_embed(
     dim_w = embed_dim // 8 * 3
 
     # Temporal frequencies
-    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
     # Spatial frequencies for height and width
-    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
-    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
 
     # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
     def combine_time_height_width(freqs_t, freqs_h, freqs_w):

@@ -686,14 +693,21 @@ def get_3d_rotary_pos_embed_allegro(
     temporal_size,
     interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
     theta: int = 10000,
+    device: Optional[torch.device] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     # TODO(aryan): docs
     start, stop = crops_coords
     grid_size_h, grid_size_w = grid_size
     interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
-    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
-    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
-    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+    grid_t = torch.linspace(
+        0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+    )
+    grid_h = torch.linspace(
+        start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+    )
+    grid_w = torch.linspace(
+        start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+    )
 
     # Compute dimensions for each axis
     dim_t = embed_dim // 3
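
`torch.linspace` has no `endpoint=` argument, so the `endpoint=False` grids are reproduced by shrinking the stop value: for a grid starting at 0, `np.linspace(0, stop, n, endpoint=False)` and `torch.linspace(0, stop * (n - 1) / n, n)` produce the same n points with step stop / n. A quick check of that identity (a sketch, not part of the diff):

    import numpy as np
    import torch

    stop, n = 45.0, 30
    old = np.linspace(0, stop, n, endpoint=False, dtype=np.float32)      # step = stop / n, last point = stop * (n - 1) / n
    new = torch.linspace(0, stop * (n - 1) / n, n, dtype=torch.float32)  # same step and same last point
    assert np.allclose(old, new.numpy())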
src/diffusers/pipelines/allegro/pipeline_allegro.py

@@ -623,20 +623,17 @@ class AllegroPipeline(DiffusionPipeline):
                 self.transformer.config.interpolation_scale_h,
                 self.transformer.config.interpolation_scale_w,
             ),
+            device=device,
         )
 
-        grid_t = torch.from_numpy(grid_t).to(device=device, dtype=torch.long)
-        grid_h = torch.from_numpy(grid_h).to(device=device, dtype=torch.long)
-        grid_w = torch.from_numpy(grid_w).to(device=device, dtype=torch.long)
+        grid_t = grid_t.to(dtype=torch.long)
+        grid_h = grid_h.to(dtype=torch.long)
+        grid_w = grid_w.to(dtype=torch.long)
 
         pos = torch.cartesian_prod(grid_t, grid_h, grid_w)
         pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous()
         grid_t, grid_h, grid_w = pos
 
-        freqs_t = (freqs_t[0].to(device=device), freqs_t[1].to(device=device))
-        freqs_h = (freqs_h[0].to(device=device), freqs_h[1].to(device=device))
-        freqs_w = (freqs_w[0].to(device=device), freqs_w[1].to(device=device))
-
         return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w)
 
     @property
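
With the grids returned as torch tensors already on the right device, the Allegro pipeline only has to cast them to integer indices before flattening them into the (t, h, w) position table via `torch.cartesian_prod`. A toy-sized sketch of that reshaping step (the sizes are made up):

    import torch

    grid_t = torch.arange(2, dtype=torch.long)   # 2 frames
    grid_h = torch.arange(3, dtype=torch.long)   # 3 rows
    grid_w = torch.arange(4, dtype=torch.long)   # 4 columns

    pos = torch.cartesian_prod(grid_t, grid_h, grid_w)            # shape (2 * 3 * 4, 3)
    pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous()
    grid_t, grid_h, grid_w = pos                                  # each of shape (1, 24)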
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

@@ -459,6 +459,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 crops_coords=grid_crops_coords,
                 grid_size=(grid_height, grid_width),
                 temporal_size=num_frames,
+                device=device,
             )
         else:
             # CogVideoX 1.5

@@ -471,10 +472,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 temporal_size=base_num_frames,
                 grid_type="slice",
                 max_size=(base_size_height, base_size_width),
+                device=device,
             )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property
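
The CogVideoX 1.5 branch forwards the same `device` through the `grid_type="slice"` path, where the grid is built up to `max_size` and sliced down to the actual resolution, now directly on the target device. An illustrative call (all numeric values are placeholders, not taken from the pipeline):

    import torch
    from diffusers.models.embeddings import get_3d_rotary_pos_embed

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
        embed_dim=64,           # placeholder attention head dim
        crops_coords=None,      # unused by the "slice" grid
        grid_size=(30, 45),     # placeholder latent grid
        temporal_size=13,       # placeholder frame count
        grid_type="slice",
        max_size=(48, 80),      # placeholder base grid; must cover grid_size
        device=device,
    )
    assert freqs_cos.device.type == device.type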
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py

@@ -505,6 +505,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 crops_coords=grid_crops_coords,
                 grid_size=(grid_height, grid_width),
                 temporal_size=num_frames,
+                device=device,
             )
         else:
             # CogVideoX 1.5

@@ -517,10 +518,9 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 temporal_size=base_num_frames,
                 grid_type="slice",
                 max_size=(base_size_height, base_size_width),
+                device=device,
             )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

@@ -555,6 +555,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                 crops_coords=grid_crops_coords,
                 grid_size=(grid_height, grid_width),
                 temporal_size=num_frames,
+                device=device,
             )
         else:
             # CogVideoX 1.5

@@ -567,10 +568,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                 temporal_size=base_num_frames,
                 grid_type="slice",
                 max_size=(base_size_height, base_size_width),
+                device=device,
             )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py

@@ -529,6 +529,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                 crops_coords=grid_crops_coords,
                 grid_size=(grid_height, grid_width),
                 temporal_size=num_frames,
+                device=device,
             )
         else:
             # CogVideoX 1.5

@@ -541,10 +542,9 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                 temporal_size=base_num_frames,
                 grid_type="slice",
                 max_size=(base_size_height, base_size_width),
+                device=device,
             )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property