Commit 2548c728 authored by lijian6

Add flash attention for sd3 medium


Signed-off-by: lijian <lijian6@sugon.com>
parent 0a4e78fc
Pipeline #1257 failed with stages in 0 seconds
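The change gates an xformers flash-attention path in the SD3 joint attention processor behind a USE_XFORMERS environment variable. A minimal usage sketch, assuming a standard diffusers StableDiffusion3Pipeline setup: the checkpoint id and prompt are illustrative, only the USE_XFORMERS flag comes from this commit, and the pipeline-wide enable call is an assumption inferred from the matching disable call in the pipeline diff below.

import os
os.environ["USE_XFORMERS"] = "1"   # switch added by this commit; "0" keeps the SDPA path

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint id
    torch_dtype=torch.float16,
).to("cuda")
# Assumption: xformers attention is enabled pipeline-wide; the diff later disables it
# before VAE decoding, which suggests it was enabled up front.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("sd3_medium_xformers.png")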
@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Any, Dict, Optional
+import os
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -119,7 +120,7 @@ class JointTransformerBlock(nn.Module):
                 f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
             )
         if hasattr(F, "scaled_dot_product_attention"):
-            processor = JointAttnProcessor2_0()
+            self.processor = JointAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
...@@ -133,7 +134,7 @@ class JointTransformerBlock(nn.Module): ...@@ -133,7 +134,7 @@ class JointTransformerBlock(nn.Module):
out_dim=attention_head_dim, out_dim=attention_head_dim,
context_pre_only=context_pre_only, context_pre_only=context_pre_only,
bias=True, bias=True,
processor=processor, processor=self.processor,
) )
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -169,6 +170,9 @@ class JointTransformerBlock(nn.Module):
             )
         # Attention.
+        use_xformers = os.getenv('USE_XFORMERS', '0')
+        if use_xformers == '1':
+            self.attn.set_processor(self.processor)
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )
...
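In the hunks above, the local `processor` variable becomes `self.processor`, presumably so that `forward()` can re-install the joint processor on `self.attn` whenever USE_XFORMERS=1 (for example, if a pipeline-level `enable_xformers_memory_efficient_attention()` call had swapped in a generic processor). A toy sketch of that pattern, with illustrative names rather than diffusers code:

import os

class ToyAttention:
    def __init__(self, processor):
        self.processor = processor

    def set_processor(self, processor):
        self.processor = processor

class ToyBlock:
    def __init__(self):
        self.processor = "JointAttnProcessor2_0"   # cached at construction time
        self.attn = ToyAttention(self.processor)

    def forward(self):
        if os.getenv("USE_XFORMERS", "0") == "1":
            # restore the cached joint processor even if something replaced it
            self.attn.set_processor(self.processor)
        return self.attn.processor

block = ToyBlock()
os.environ["USE_XFORMERS"] = "1"
print(block.forward())  # JointAttnProcessor2_0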
@@ -16,6 +16,7 @@ import math
 from importlib import import_module
 from typing import Callable, List, Optional, Union
+import os
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -25,6 +26,7 @@ from ..utils import deprecate, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .lora import LoRALinearLayer
+from xformers.ops import MemoryEfficientAttentionFlashAttentionOp, MemoryEfficientAttentionTritonFwdFlashBwOp
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
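Note that the new top-level `from xformers.ops import ...` makes xformers a hard dependency of this module: importing the file fails when xformers is not installed. A hedged sketch of a guarded import that mirrors the `is_xformers_available()` pattern already imported in this hunk:

try:
    from xformers.ops import (
        MemoryEfficientAttentionFlashAttentionOp,
        MemoryEfficientAttentionTritonFwdFlashBwOp,
    )
except ImportError:
    # xformers not installed: the USE_XFORMERS=1 path cannot be used
    MemoryEfficientAttentionFlashAttentionOp = None
    MemoryEfficientAttentionTritonFwdFlashBwOp = None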
@@ -1127,17 +1129,29 @@ class JointAttnProcessor2_0:
         key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
         value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        hidden_states = hidden_states = F.scaled_dot_product_attention(
-            query, key, value, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        use_xformers = os.getenv('USE_XFORMERS', '0')
+        if use_xformers == '1':
+            query = attn.head_to_batch_dim(query).contiguous()
+            key = attn.head_to_batch_dim(key).contiguous()
+            value = attn.head_to_batch_dim(value).contiguous()
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query, key, value, op=MemoryEfficientAttentionTritonFwdFlashBwOp
+            )
+            hidden_states = hidden_states.to(query.dtype)
+            hidden_states = attn.batch_to_head_dim(hidden_states)
+        else:
+            inner_dim = key.shape[-1]
+            head_dim = inner_dim // attn.heads
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            hidden_states = hidden_states = F.scaled_dot_product_attention(
+                query, key, value, dropout_p=0.0, is_causal=False
+            )
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            hidden_states = hidden_states.to(query.dtype)
         # Split the attention outputs.
         hidden_states, encoder_hidden_states = (
...
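The two branches differ mainly in tensor layout: `head_to_batch_dim` packs the heads into the batch axis ([batch * heads, seq, head_dim]) for xformers, while the SDPA branch keeps a separate heads axis ([batch, heads, seq, head_dim]). A self-contained sketch (illustrative shapes only) showing the layouts are interchangeable up to a reshape, with `F.scaled_dot_product_attention` standing in for `memory_efficient_attention` since both compute softmax(Q K^T / sqrt(d)) V:

import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 4, 16, 8
q = torch.randn(batch, seq, heads * head_dim)
k = torch.randn(batch, seq, heads * head_dim)
v = torch.randn(batch, seq, heads * head_dim)

# SDPA branch layout: [batch, heads, seq, head_dim]
q4 = q.view(batch, seq, heads, head_dim).transpose(1, 2)
k4 = k.view(batch, seq, heads, head_dim).transpose(1, 2)
v4 = v.view(batch, seq, heads, head_dim).transpose(1, 2)
out4 = F.scaled_dot_product_attention(q4, k4, v4)
out4 = out4.transpose(1, 2).reshape(batch, seq, heads * head_dim)

# xformers-style layout: heads folded into the batch axis, [batch * heads, seq, head_dim]
q3 = q4.reshape(batch * heads, seq, head_dim)
k3 = k4.reshape(batch * heads, seq, head_dim)
v3 = v4.reshape(batch * heads, seq, head_dim)
out3 = F.scaled_dot_product_attention(q3, k3, v3)
out3 = out3.reshape(batch, heads, seq, head_dim).transpose(1, 2).reshape(batch, seq, heads * head_dim)

print(torch.allclose(out4, out3, atol=1e-6))  # True: same attention, different packing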
@@ -36,7 +36,7 @@ from ...utils import (
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import StableDiffusion3PipelineOutput
+import os
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -868,6 +868,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                 if XLA_AVAILABLE:
                     xm.mark_step()
+        use_xformers = os.getenv('USE_XFORMERS', '0')
+        if use_xformers == '1':
+            self.disable_xformers_memory_efficient_attention()
         if output_type == "latent":
             image = latents
...