Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
7853bfbe
Unverified
Commit
7853bfbe
authored
Oct 19, 2025
by
dg845
Committed by
GitHub
Oct 19, 2025
Browse files
Remove Qwen Image Redundant RoPE Cache (#12452)
Refactor QwenEmbedRope to only use the LRU cache for RoPE caching
parent
23ebbb4b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
14 deletions
+17
-14
src/diffusers/models/transformers/transformer_qwenimage.py
src/diffusers/models/transformers/transformer_qwenimage.py
+17
-14
No files found.
src/diffusers/models/transformers/transformer_qwenimage.py
View file @
7853bfbe
...
...
@@ -180,7 +180,6 @@ class QwenEmbedRope(nn.Module):
],
dim
=
1
,
)
self
.
rope_cache
=
{}
# DO NOT USE REGISTER_BUFFER HERE; IT WILL CAUSE COMPLEX NUMBERS TO LOSE THEIR IMAGINARY PART
self
.
scale_rope
=
scale_rope
...
...
@@ -195,10 +194,20 @@ class QwenEmbedRope(nn.Module):
freqs
=
torch
.
polar
(
torch
.
ones_like
(
freqs
),
freqs
)
return
freqs
def
forward
(
self
,
video_fhw
,
txt_seq_lens
,
device
):
def
forward
(
self
,
video_fhw
:
Union
[
Tuple
[
int
,
int
,
int
],
List
[
Tuple
[
int
,
int
,
int
]]],
txt_seq_lens
:
List
[
int
],
device
:
torch
.
device
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Args: video_fhw: [frame, height, width] — a list of 3 integers representing the shape of the video;
txt_seq_lens: [bs] — a list of integers representing the length of each text prompt
Args:
video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
A list of 3 integers [frame, height, width] representing the shape of the video.
txt_seq_lens (`List[int]`):
A list of integers of length batch_size representing the length of each text prompt.
device (`torch.device`):
The device on which to perform the RoPE computation.
"""
if
self
.
pos_freqs
.
device
!=
device
:
self
.
pos_freqs
=
self
.
pos_freqs
.
to
(
device
)
...
...
@@ -213,14 +222,8 @@ class QwenEmbedRope(nn.Module):
max_vid_index
=
0
for
idx
,
fhw
in
enumerate
(
video_fhw
):
frame
,
height
,
width
=
fhw
rope_key
=
f
"
{
idx
}
_
{
height
}
_
{
width
}
"
if
not
torch
.
compiler
.
is_compiling
():
if
rope_key
not
in
self
.
rope_cache
:
self
.
rope_cache
[
rope_key
]
=
self
.
_compute_video_freqs
(
frame
,
height
,
width
,
idx
)
video_freq
=
self
.
rope_cache
[
rope_key
]
else
:
video_freq
=
self
.
_compute_video_freqs
(
frame
,
height
,
width
,
idx
)
# RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
video_freq
=
self
.
_compute_video_freqs
(
frame
,
height
,
width
,
idx
)
video_freq
=
video_freq
.
to
(
device
)
vid_freqs
.
append
(
video_freq
)
...
...
@@ -235,8 +238,8 @@ class QwenEmbedRope(nn.Module):
return
vid_freqs
,
txt_freqs
@
functools
.
lru_cache
(
maxsize
=
None
)
def
_compute_video_freqs
(
self
,
frame
,
height
,
width
,
idx
=
0
)
:
@
functools
.
lru_cache
(
maxsize
=
128
)
def
_compute_video_freqs
(
self
,
frame
:
int
,
height
:
int
,
width
:
int
,
idx
:
int
=
0
)
->
torch
.
Tensor
:
seq_lens
=
frame
*
height
*
width
freqs_pos
=
self
.
pos_freqs
.
split
([
x
//
2
for
x
in
self
.
axes_dim
],
dim
=
1
)
freqs_neg
=
self
.
neg_freqs
.
split
([
x
//
2
for
x
in
self
.
axes_dim
],
dim
=
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment