Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
093cd3f0
Unverified
Commit
093cd3f0
authored
Nov 11, 2025
by
YiYi Xu
Committed by
GitHub
Nov 11, 2025
Browse files
fix dispatch_attention_fn check (#12636)
* fix * fix
parent
aecf0c53
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
6 deletions
+12
-6
src/diffusers/models/attention_dispatch.py
src/diffusers/models/attention_dispatch.py
+11
-5
src/diffusers/utils/constants.py
src/diffusers/utils/constants.py
+1
-1
No files found.
src/diffusers/models/attention_dispatch.py
View file @
093cd3f0
...
@@ -383,12 +383,18 @@ def _check_shape(
...
@@ -383,12 +383,18 @@ def _check_shape(
attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
# Expected shapes:
# query: (batch_size, seq_len_q, num_heads, head_dim)
# key: (batch_size, seq_len_kv, num_heads, head_dim)
# value: (batch_size, seq_len_kv, num_heads, head_dim)
# attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv)
# or (batch_size, num_heads, seq_len_q, seq_len_kv)
if
query
.
shape
[
-
1
]
!=
key
.
shape
[
-
1
]:
if
query
.
shape
[
-
1
]
!=
key
.
shape
[
-
1
]:
raise
ValueError
(
"Query and key must have the same
last
dimension."
)
raise
ValueError
(
"Query and key must have the same
head
dimension."
)
if
quer
y
.
shape
[
-
2
]
!=
value
.
shape
[
-
2
]:
if
ke
y
.
shape
[
-
3
]
!=
value
.
shape
[
-
3
]:
raise
ValueError
(
"
Quer
y and value must have the same se
cond to last dimension
."
)
raise
ValueError
(
"
Ke
y and value must have the same se
quence length
."
)
if
attn_mask
is
not
None
and
attn_mask
.
shape
[
-
1
]
!=
key
.
shape
[
-
2
]:
if
attn_mask
is
not
None
and
attn_mask
.
shape
[
-
1
]
!=
key
.
shape
[
-
3
]:
raise
ValueError
(
"Attention mask must match the key's se
cond to last dimension
."
)
raise
ValueError
(
"Attention mask must match the key's se
quence length
."
)
# ===== Helper functions =====
# ===== Helper functions =====
...
...
src/diffusers/utils/constants.py
View file @
093cd3f0
...
@@ -42,7 +42,7 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"
...
@@ -42,7 +42,7 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"
DEPRECATED_REVISION_ARGS
=
[
"fp16"
,
"non-ema"
]
DEPRECATED_REVISION_ARGS
=
[
"fp16"
,
"non-ema"
]
DIFFUSERS_REQUEST_TIMEOUT
=
60
DIFFUSERS_REQUEST_TIMEOUT
=
60
DIFFUSERS_ATTN_BACKEND
=
os
.
getenv
(
"DIFFUSERS_ATTN_BACKEND"
,
"native"
)
DIFFUSERS_ATTN_BACKEND
=
os
.
getenv
(
"DIFFUSERS_ATTN_BACKEND"
,
"native"
)
DIFFUSERS_ATTN_CHECKS
=
os
.
getenv
(
"DIFFUSERS_ATTN_CHECKS"
,
"0"
)
in
ENV_VARS_TRUE_VALUES
DIFFUSERS_ATTN_CHECKS
=
os
.
getenv
(
"DIFFUSERS_ATTN_CHECKS"
,
"0"
)
.
upper
()
in
ENV_VARS_TRUE_VALUES
DEFAULT_HF_PARALLEL_LOADING_WORKERS
=
8
DEFAULT_HF_PARALLEL_LOADING_WORKERS
=
8
HF_ENABLE_PARALLEL_LOADING
=
os
.
environ
.
get
(
"HF_ENABLE_PARALLEL_LOADING"
,
""
).
upper
()
in
ENV_VARS_TRUE_VALUES
HF_ENABLE_PARALLEL_LOADING
=
os
.
environ
.
get
(
"HF_ENABLE_PARALLEL_LOADING"
,
""
).
upper
()
in
ENV_VARS_TRUE_VALUES
DIFFUSERS_DISABLE_REMOTE_CODE
=
os
.
getenv
(
"DIFFUSERS_DISABLE_REMOTE_CODE"
,
"false"
).
upper
()
in
ENV_VARS_TRUE_VALUES
DIFFUSERS_DISABLE_REMOTE_CODE
=
os
.
getenv
(
"DIFFUSERS_DISABLE_REMOTE_CODE"
,
"false"
).
upper
()
in
ENV_VARS_TRUE_VALUES
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment