Commit 2548c728 authored by lijian6

Add flash attention for sd3 medium


Signed-off-by: lijian <lijian6@sugon.com>
parent 0a4e78fc
Pipeline #1257 failed with stages in 0 seconds
@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Any, Dict, Optional
+import os
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -119,7 +120,7 @@ class JointTransformerBlock(nn.Module):
                 f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
             )
         if hasattr(F, "scaled_dot_product_attention"):
-            processor = JointAttnProcessor2_0()
+            self.processor = JointAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
@@ -133,7 +134,7 @@
             out_dim=attention_head_dim,
             context_pre_only=context_pre_only,
             bias=True,
-            processor=processor,
+            processor=self.processor,
         )
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -169,6 +170,9 @@
             )
         # Attention.
+        use_xformers = os.getenv('USE_XFORMERS', '0')
+        if use_xformers == '1':
+            self.attn.set_processor(self.processor)
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )
......
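For reference, a minimal sketch of the runtime switch used above (not part of the commit; the helper name is illustrative). The block keeps a JointAttnProcessor2_0 instance and, whenever the USE_XFORMERS environment variable is set to 1, re-installs it on the Attention module via set_processor, so the joint processor with the xformers branch shown below handles the forward pass.

import os
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0

def maybe_restore_joint_processor(attn: Attention, processor: JointAttnProcessor2_0) -> None:
    # Re-install the joint processor when USE_XFORMERS=1; otherwise leave
    # whatever processor is currently configured on the Attention module.
    if os.getenv("USE_XFORMERS", "0") == "1":
        attn.set_processor(processor)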
@@ -16,6 +16,7 @@ import math
 from importlib import import_module
 from typing import Callable, List, Optional, Union
+import os
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -25,6 +26,7 @@ from ..utils import deprecate, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .lora import LoRALinearLayer
+from xformers.ops import MemoryEfficientAttentionFlashAttentionOp, MemoryEfficientAttentionTritonFwdFlashBwOp
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1127,6 +1129,18 @@ class JointAttnProcessor2_0:
         key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
         value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+        use_xformers = os.getenv('USE_XFORMERS', '0')
+        if use_xformers == '1':
+            query = attn.head_to_batch_dim(query).contiguous()
+            key = attn.head_to_batch_dim(key).contiguous()
+            value = attn.head_to_batch_dim(value).contiguous()
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query, key, value, op=MemoryEfficientAttentionTritonFwdFlashBwOp
+            )
+            hidden_states = hidden_states.to(query.dtype)
+            hidden_states = attn.batch_to_head_dim(hidden_states)
+        else:
+            inner_dim = key.shape[-1]
+            head_dim = inner_dim // attn.heads
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
......
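As a standalone illustration of the xformers path added above (a sketch that assumes xformers is installed; the function name and shapes are not from the commit), the projected query/key/value tensors of shape [batch, seq_len, heads * head_dim] are folded into the 3D [batch * heads, seq_len, head_dim] layout accepted by xformers.ops.memory_efficient_attention, attention runs through a memory-efficient kernel, and the output is unfolded back. The commit pins op=MemoryEfficientAttentionTritonFwdFlashBwOp; leaving op unset lets xformers choose a backend.

import torch
import xformers.ops

def flash_attention(query, key, value, heads):
    # query/key/value: [batch, seq_len, heads * head_dim]
    batch, q_len, inner_dim = query.shape
    head_dim = inner_dim // heads

    def to_batch(t):
        # [batch, seq, heads * dim] -> [batch * heads, seq, dim]
        return (
            t.reshape(batch, -1, heads, head_dim)
            .permute(0, 2, 1, 3)
            .reshape(batch * heads, -1, head_dim)
            .contiguous()
        )

    q, k, v = map(to_batch, (query, key, value))
    out = xformers.ops.memory_efficient_attention(q, k, v)
    # [batch * heads, seq, dim] -> [batch, seq, heads * dim]
    return (
        out.reshape(batch, heads, q_len, head_dim)
        .permute(0, 2, 1, 3)
        .reshape(batch, q_len, inner_dim)
    )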
@@ -36,7 +36,7 @@ from ...utils import (
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import StableDiffusion3PipelineOutput
+import os
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -868,6 +868,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                 if XLA_AVAILABLE:
                     xm.mark_step()
+        use_xformers = os.getenv('USE_XFORMERS', '0')
+        if use_xformers == '1':
+            self.disable_xformers_memory_efficient_attention()
         if output_type == "latent":
             image = latents
......
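A possible end-to-end usage sketch, assuming this patched diffusers build and xformers are installed and the SD3-medium weights are available (the checkpoint id, prompt, and output path are placeholders): set USE_XFORMERS=1 before inference so the patched JointAttnProcessor2_0 takes the xformers branch.

import os
os.environ["USE_XFORMERS"] = "1"  # read by the patched processor on every forward pass

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("sd3_flash_attention.png")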