Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
780b3a4f
Unverified
Commit
780b3a4f
authored
Feb 17, 2023
by
Pedro Cuenca
Committed by
GitHub
Feb 17, 2023
Browse files
Fix typo in AttnProcessor2_0 symbol (#2404)
Fix typo in AttnProcessor2_0 symbol.
parent
07547dfa
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
6 deletions
+6
-6
docs/source/en/optimization/torch2.0.mdx
docs/source/en/optimization/torch2.0.mdx
+2
-2
src/diffusers/models/cross_attention.py
src/diffusers/models/cross_attention.py
+4
-4
No files found.
docs/source/en/optimization/torch2.0.mdx
View file @
780b3a4f
...
...
@@ -50,10 +50,10 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
```Python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProc
c
esor2_0
from diffusers.models.cross_attention import AttnProce
s
sor2_0
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.unet.set_attn_processor(AttnProc
c
esor2_0())
pipe.unet.set_attn_processor(AttnProce
s
sor2_0())
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
...
...
src/diffusers/models/cross_attention.py
View file @
780b3a4f
...
...
@@ -99,10 +99,10 @@ class CrossAttention(nn.Module):
self
.
to_out
.
append
(
nn
.
Dropout
(
dropout
))
# set attention processor
# We use the AttnProc
c
esor2_0 by default when torch2.x is used which uses
# We use the AttnProce
s
sor2_0 by default when torch2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
if
processor
is
None
:
processor
=
AttnProc
c
esor2_0
()
if
hasattr
(
F
,
"scaled_dot_product_attention"
)
else
CrossAttnProcessor
()
processor
=
AttnProce
s
sor2_0
()
if
hasattr
(
F
,
"scaled_dot_product_attention"
)
else
CrossAttnProcessor
()
self
.
set_processor
(
processor
)
def
set_use_memory_efficient_attention_xformers
(
...
...
@@ -466,10 +466,10 @@ class XFormersCrossAttnProcessor:
return
hidden_states
class
AttnProc
c
esor2_0
:
class
AttnProce
s
sor2_0
:
def
__init__
(
self
):
if
not
hasattr
(
F
,
"scaled_dot_product_attention"
):
raise
ImportError
(
"AttnProc
c
esor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
raise
ImportError
(
"AttnProce
s
sor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def
__call__
(
self
,
attn
:
CrossAttention
,
hidden_states
,
encoder_hidden_states
=
None
,
attention_mask
=
None
):
batch_size
,
sequence_length
,
inner_dim
=
hidden_states
.
shape
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment