renzhc / diffusers_dcu · Commits · 0071478d

Unverified commit 0071478d, authored Feb 09, 2024 by YiYi Xu, committed by GitHub on Feb 09, 2024.
allow attention processors to have different signatures (#6915)
add

Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent 7c8cab31
Showing 1 changed file with 10 additions and 0 deletions.
src/diffusers/models/attention_processor.py (+10, -0)
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 from importlib import import_module
 from typing import Callable, Optional, Union
@@ -509,6 +510,15 @@ class Attention(nn.Module):
         # The `Attention` class can call different attention processors / attention functions
         # here we simply pass along all tensors to the selected processor class
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
         return self.processor(
             self,
             hidden_states,
         ...
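
For context, the added code works because inspect.signature exposes the parameter names a processor's __call__ declares, so any extra entries in cross_attention_kwargs can be dropped with a warning instead of raising a TypeError when processors have different signatures. The following is a minimal standalone sketch of that mechanism, not diffusers code: the names BasicProcessor, ScaledProcessor, and run_attention are hypothetical, and warnings.warn stands in for the library's logger.warning.

# A minimal sketch (assumed names, not diffusers code) of signature-based kwarg filtering.
import inspect
import warnings


class BasicProcessor:
    # Declares no extra keyword arguments.
    def __call__(self, attn, hidden_states):
        return hidden_states


class ScaledProcessor:
    # Declares an extra keyword argument that BasicProcessor does not accept.
    def __call__(self, attn, hidden_states, scale=1.0):
        return hidden_states * scale


def run_attention(processor, hidden_states, **cross_attention_kwargs):
    # Same idea as the diff: keep only the kwargs the processor's __call__ declares,
    # and warn about the rest instead of letting the call fail.
    attn_parameters = set(inspect.signature(processor.__call__).parameters.keys())
    unused_kwargs = [k for k in cross_attention_kwargs if k not in attn_parameters]
    if unused_kwargs:
        warnings.warn(
            f"cross_attention_kwargs {unused_kwargs} are not expected by "
            f"{processor.__class__.__name__} and will be ignored."
        )
    cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
    return processor(None, hidden_states, **cross_attention_kwargs)


print(run_attention(ScaledProcessor(), 3.0, scale=2.0))  # 6.0
print(run_attention(BasicProcessor(), 3.0, scale=2.0))   # 3.0, plus a warning that `scale` is ignored

In this sketch, passing scale=2.0 is honored by the processor that declares it and dropped with a warning by the one that does not, which mirrors why the commit lets attention processors have different signatures.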