Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
2548c728
Commit
2548c728
authored
Jun 24, 2024
by
lijian6
Browse files
Add flash attention for sd3 medium
Signed-off-by:
lijian
<
lijian6@sugon.com
>
parent
0a4e78fc
Pipeline
#1257
failed with stages
in 0 seconds
Changes
3
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
13 deletions
+34
-13
src/diffusers/models/attention.py
src/diffusers/models/attention.py
+6
-2
src/diffusers/models/attention_processor.py
src/diffusers/models/attention_processor.py
+24
-10
src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
...pelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+4
-1
No files found.
src/diffusers/models/attention.py
View file @
2548c728
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
import
os
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch
import
nn
...
@@ -119,7 +120,7 @@ class JointTransformerBlock(nn.Module):
...
@@ -119,7 +120,7 @@ class JointTransformerBlock(nn.Module):
f
"Unknown context_norm_type:
{
context_norm_type
}
, currently only support `ada_norm_continous`, `ada_norm_zero`"
f
"Unknown context_norm_type:
{
context_norm_type
}
, currently only support `ada_norm_continous`, `ada_norm_zero`"
)
)
if
hasattr
(
F
,
"scaled_dot_product_attention"
):
if
hasattr
(
F
,
"scaled_dot_product_attention"
):
processor
=
JointAttnProcessor2_0
()
self
.
processor
=
JointAttnProcessor2_0
()
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
...
@@ -133,7 +134,7 @@ class JointTransformerBlock(nn.Module):
...
@@ -133,7 +134,7 @@ class JointTransformerBlock(nn.Module):
out_dim
=
attention_head_dim
,
out_dim
=
attention_head_dim
,
context_pre_only
=
context_pre_only
,
context_pre_only
=
context_pre_only
,
bias
=
True
,
bias
=
True
,
processor
=
processor
,
processor
=
self
.
processor
,
)
)
self
.
norm2
=
nn
.
LayerNorm
(
dim
,
elementwise_affine
=
False
,
eps
=
1e-6
)
self
.
norm2
=
nn
.
LayerNorm
(
dim
,
elementwise_affine
=
False
,
eps
=
1e-6
)
...
@@ -169,6 +170,9 @@ class JointTransformerBlock(nn.Module):
...
@@ -169,6 +170,9 @@ class JointTransformerBlock(nn.Module):
)
)
# Attention.
# Attention.
use_xformers
=
os
.
getenv
(
'USE_XFORMERS'
,
'0'
)
if
use_xformers
==
'1'
:
self
.
attn
.
set_processor
(
self
.
processor
)
attn_output
,
context_attn_output
=
self
.
attn
(
attn_output
,
context_attn_output
=
self
.
attn
(
hidden_states
=
norm_hidden_states
,
encoder_hidden_states
=
norm_encoder_hidden_states
hidden_states
=
norm_hidden_states
,
encoder_hidden_states
=
norm_encoder_hidden_states
)
)
...
...
src/diffusers/models/attention_processor.py
View file @
2548c728
...
@@ -16,6 +16,7 @@ import math
...
@@ -16,6 +16,7 @@ import math
from
importlib
import
import_module
from
importlib
import
import_module
from
typing
import
Callable
,
List
,
Optional
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Union
import
os
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch
import
nn
...
@@ -25,6 +26,7 @@ from ..utils import deprecate, logging
...
@@ -25,6 +26,7 @@ from ..utils import deprecate, logging
from
..utils.import_utils
import
is_torch_npu_available
,
is_xformers_available
from
..utils.import_utils
import
is_torch_npu_available
,
is_xformers_available
from
..utils.torch_utils
import
maybe_allow_in_graph
from
..utils.torch_utils
import
maybe_allow_in_graph
from
.lora
import
LoRALinearLayer
from
.lora
import
LoRALinearLayer
from
xformers.ops
import
MemoryEfficientAttentionFlashAttentionOp
,
MemoryEfficientAttentionTritonFwdFlashBwOp
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -1127,6 +1129,18 @@ class JointAttnProcessor2_0:
...
@@ -1127,6 +1129,18 @@ class JointAttnProcessor2_0:
key
=
torch
.
cat
([
key
,
encoder_hidden_states_key_proj
],
dim
=
1
)
key
=
torch
.
cat
([
key
,
encoder_hidden_states_key_proj
],
dim
=
1
)
value
=
torch
.
cat
([
value
,
encoder_hidden_states_value_proj
],
dim
=
1
)
value
=
torch
.
cat
([
value
,
encoder_hidden_states_value_proj
],
dim
=
1
)
use_xformers
=
os
.
getenv
(
'USE_XFORMERS'
,
'0'
)
if
use_xformers
==
'1'
:
query
=
attn
.
head_to_batch_dim
(
query
).
contiguous
()
key
=
attn
.
head_to_batch_dim
(
key
).
contiguous
()
value
=
attn
.
head_to_batch_dim
(
value
).
contiguous
()
hidden_states
=
xformers
.
ops
.
memory_efficient_attention
(
query
,
key
,
value
,
op
=
MemoryEfficientAttentionTritonFwdFlashBwOp
)
hidden_states
=
hidden_states
.
to
(
query
.
dtype
)
hidden_states
=
attn
.
batch_to_head_dim
(
hidden_states
)
else
:
inner_dim
=
key
.
shape
[
-
1
]
inner_dim
=
key
.
shape
[
-
1
]
head_dim
=
inner_dim
//
attn
.
heads
head_dim
=
inner_dim
//
attn
.
heads
query
=
query
.
view
(
batch_size
,
-
1
,
attn
.
heads
,
head_dim
).
transpose
(
1
,
2
)
query
=
query
.
view
(
batch_size
,
-
1
,
attn
.
heads
,
head_dim
).
transpose
(
1
,
2
)
...
...
src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
View file @
2548c728
...
@@ -36,7 +36,7 @@ from ...utils import (
...
@@ -36,7 +36,7 @@ from ...utils import (
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
from
..pipeline_utils
import
DiffusionPipeline
from
.pipeline_output
import
StableDiffusion3PipelineOutput
from
.pipeline_output
import
StableDiffusion3PipelineOutput
import
os
if
is_torch_xla_available
():
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
...
@@ -868,6 +868,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
...
@@ -868,6 +868,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
if
XLA_AVAILABLE
:
if
XLA_AVAILABLE
:
xm
.
mark_step
()
xm
.
mark_step
()
use_xformers
=
os
.
getenv
(
'USE_XFORMERS'
,
'0'
)
if
use_xformers
==
'1'
:
self
.
disable_xformers_memory_efficient_attention
()
if
output_type
==
"latent"
:
if
output_type
==
"latent"
:
image
=
latents
image
=
latents
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment