Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
0ac52d6f
Unverified
Commit
0ac52d6f
authored
Dec 18, 2024
by
hlky
Committed by
GitHub
Dec 17, 2024
Browse files
Use `torch` in `get_2d_rotary_pos_embed` (#10155)
* Use `torch` in `get_2d_rotary_pos_embed` * Add deprecation
parent
ba6fd6eb
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
68 additions
and
4 deletions
+68
-4
examples/community/pipeline_hunyuandit_differential_img2img.py
...les/community/pipeline_hunyuandit_differential_img2img.py
+2
-0
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+51
-1
src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
...s/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
+5
-1
src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
+5
-1
src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
+5
-1
No files found.
examples/community/pipeline_hunyuandit_differential_img2img.py
View file @
0ac52d6f
...
...
@@ -1008,6 +1008,8 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
self
.
transformer
.
inner_dim
//
self
.
transformer
.
num_heads
,
grid_crops_coords
,
(
grid_height
,
grid_width
),
device
=
device
,
output_type
=
"pt"
,
)
style
=
torch
.
tensor
([
0
],
device
=
device
)
...
...
src/diffusers/models/embeddings.py
View file @
0ac52d6f
...
...
@@ -957,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro(
return
freqs_t
,
freqs_h
,
freqs_w
,
grid_t
,
grid_h
,
grid_w
def
get_2d_rotary_pos_embed
(
embed_dim
,
crops_coords
,
grid_size
,
use_real
=
True
):
def
get_2d_rotary_pos_embed
(
embed_dim
,
crops_coords
,
grid_size
,
use_real
=
True
,
device
:
Optional
[
torch
.
device
]
=
None
,
output_type
:
str
=
"np"
):
"""
RoPE for image tokens with 2d structure.
Args:
embed_dim: (`int`):
The embedding dimension size
crops_coords (`Tuple[int]`)
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the positional embedding.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
device: (`torch.device`, **optional**):
The device used to create tensors.
Returns:
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
"""
if
output_type
==
"np"
:
deprecation_message
=
(
"`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate
(
"output_type=='np'"
,
"0.33.0"
,
deprecation_message
,
standard_warn
=
False
)
return
_get_2d_rotary_pos_embed_np
(
embed_dim
=
embed_dim
,
crops_coords
=
crops_coords
,
grid_size
=
grid_size
,
use_real
=
use_real
,
)
start
,
stop
=
crops_coords
# scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
grid_h
=
torch
.
linspace
(
start
[
0
],
stop
[
0
]
*
(
grid_size
[
0
]
-
1
)
/
grid_size
[
0
],
grid_size
[
0
],
device
=
device
,
dtype
=
torch
.
float32
)
grid_w
=
torch
.
linspace
(
start
[
1
],
stop
[
1
]
*
(
grid_size
[
1
]
-
1
)
/
grid_size
[
1
],
grid_size
[
1
],
device
=
device
,
dtype
=
torch
.
float32
)
grid
=
torch
.
meshgrid
(
grid_w
,
grid_h
,
indexing
=
"xy"
)
grid
=
torch
.
stack
(
grid
,
dim
=
0
)
# [2, W, H]
grid
=
grid
.
reshape
([
2
,
1
,
*
grid
.
shape
[
1
:]])
pos_embed
=
get_2d_rotary_pos_embed_from_grid
(
embed_dim
,
grid
,
use_real
=
use_real
)
return
pos_embed
def
_get_2d_rotary_pos_embed_np
(
embed_dim
,
crops_coords
,
grid_size
,
use_real
=
True
):
"""
RoPE for image tokens with 2d structure.
...
...
src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
View file @
0ac52d6f
...
...
@@ -925,7 +925,11 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
base_size
=
512
//
8
//
self
.
transformer
.
config
.
patch_size
grid_crops_coords
=
get_resize_crop_region_for_grid
((
grid_height
,
grid_width
),
base_size
)
image_rotary_emb
=
get_2d_rotary_pos_embed
(
self
.
transformer
.
inner_dim
//
self
.
transformer
.
num_heads
,
grid_crops_coords
,
(
grid_height
,
grid_width
)
self
.
transformer
.
inner_dim
//
self
.
transformer
.
num_heads
,
grid_crops_coords
,
(
grid_height
,
grid_width
),
device
=
device
,
output_type
=
"pt"
,
)
style
=
torch
.
tensor
([
0
],
device
=
device
)
...
...
src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
View file @
0ac52d6f
...
...
@@ -798,7 +798,11 @@ class HunyuanDiTPipeline(DiffusionPipeline):
base_size
=
512
//
8
//
self
.
transformer
.
config
.
patch_size
grid_crops_coords
=
get_resize_crop_region_for_grid
((
grid_height
,
grid_width
),
base_size
)
image_rotary_emb
=
get_2d_rotary_pos_embed
(
self
.
transformer
.
inner_dim
//
self
.
transformer
.
num_heads
,
grid_crops_coords
,
(
grid_height
,
grid_width
)
self
.
transformer
.
inner_dim
//
self
.
transformer
.
num_heads
,
grid_crops_coords
,
(
grid_height
,
grid_width
),
device
=
device
,
output_type
=
"pt"
,
)
style
=
torch
.
tensor
([
0
],
device
=
device
)
...
...
src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
View file @
0ac52d6f
...
...
@@ -818,7 +818,11 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
base_size
=
512
//
8
//
self
.
transformer
.
config
.
patch_size
grid_crops_coords
=
get_resize_crop_region_for_grid
((
grid_height
,
grid_width
),
base_size
)
image_rotary_emb
=
get_2d_rotary_pos_embed
(
self
.
transformer
.
inner_dim
//
self
.
transformer
.
num_heads
,
grid_crops_coords
,
(
grid_height
,
grid_width
)
self
.
transformer
.
inner_dim
//
self
.
transformer
.
num_heads
,
grid_crops_coords
,
(
grid_height
,
grid_width
),
device
=
device
,
output_type
=
"pt"
,
)
style
=
torch
.
tensor
([
0
],
device
=
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment