Unverified commit 780b3a4f, authored Feb 17, 2023 by Pedro Cuenca; committed by GitHub on Feb 17, 2023. Parent: 07547dfa

Fix typo in AttnProcessor2_0 symbol (#2404)
Showing 2 changed files with 6 additions and 6 deletions:

- docs/source/en/optimization/torch2.0.mdx (+2, -2)
- src/diffusers/models/cross_attention.py (+4, -4)
docs/source/en/optimization/torch2.0.mdx

```diff
@@ -50,10 +50,10 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 ```Python
 import torch
 from diffusers import StableDiffusionPipeline
-from diffusers.models.cross_attention import AttnProccesor2_0
+from diffusers.models.cross_attention import AttnProcessor2_0
 
 pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
-pipe.unet.set_attn_processor(AttnProccesor2_0())
+pipe.unet.set_attn_processor(AttnProcessor2_0())
 
 prompt = "a photo of an astronaut riding a horse on mars"
 image = pipe(prompt).images[0]
```
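Note that `AttnProcessor2_0` only exists when PyTorch 2.0 is installed, so a guarded variant of the snippet above can fall back to the stock processor on older versions. A minimal sketch (the fallback guard is illustrative, not part of this commit):

```Python
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Opt in to the fused attention path only when the PyTorch 2.0 kernel exists;
# otherwise the pipeline keeps its default cross-attention processor.
if hasattr(F, "scaled_dot_product_attention"):
    from diffusers.models.cross_attention import AttnProcessor2_0
    pipe.unet.set_attn_processor(AttnProcessor2_0())

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```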
src/diffusers/models/cross_attention.py

```diff
@@ -99,10 +99,10 @@ class CrossAttention(nn.Module):
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
-        # We use the AttnProccesor2_0 by default when torch2.x is used which uses
+        # We use the AttnProcessor2_0 by default when torch2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         if processor is None:
-            processor = AttnProccesor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
+            processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
         self.set_processor(processor)
 
     def set_use_memory_efficient_attention_xformers(
```
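The `hasattr(F, "scaled_dot_product_attention")` test in this hunk doubles as a version gate: the function was added in PyTorch 2.0, so earlier releases simply lack the attribute. A quick check, assuming only that PyTorch is importable:

```Python
import torch
import torch.nn.functional as F

# True on PyTorch >= 2.0, False on 1.x -- the same test the diff uses
# to pick between AttnProcessor2_0 and CrossAttnProcessor.
print(torch.__version__, hasattr(F, "scaled_dot_product_attention"))
```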
```diff
@@ -466,10 +466,10 @@ class XFormersCrossAttnProcessor:
         return hidden_states
 
 
-class AttnProccesor2_0:
+class AttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProccesor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, inner_dim = hidden_states.shape
```
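For context, `torch.nn.functional.scaled_dot_product_attention` (the kernel the renamed `AttnProcessor2_0` wraps) takes query/key/value tensors shaped `(batch, heads, seq_len, head_dim)` and dispatches to a fused Flash or memory-efficient backend when one applies. A standalone sketch with made-up shapes:

```Python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch=2, heads=8, seq_len=64, head_dim=40.
q = torch.randn(2, 8, 64, 40)
k = torch.randn(2, 8, 64, 40)
v = torch.randn(2, 8, 64, 40)

# PyTorch 2.0 selects a Flash / memory-efficient / math backend automatically.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 64, 40])
```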