Commit a3528339 authored by lijian6

Add SD3 Triton flash attention (FA).


Signed-off-by: lijian <lijian6@sugon.com>
parent 8e4e71c8
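
Note: the new code path is gated entirely by the `SD3_USE_XFORMERS` environment variable. A minimal usage sketch (not part of this commit; the checkpoint id, dtype, and prompt are illustrative assumptions) showing how the flag would be set before running the SD3 pipeline:

    import os

    # Assumption: setting SD3_USE_XFORMERS=1 before the pipeline runs routes the joint
    # attention blocks through the xformers Triton flash-attention path added in this commit.
    os.environ['SD3_USE_XFORMERS'] = '1'

    import torch
    from diffusers import StableDiffusion3Pipeline

    pipe = StableDiffusion3Pipeline.from_pretrained(
        'stabilityai/stable-diffusion-3-medium-diffusers',  # illustrative checkpoint
        torch_dtype=torch.float16,
    ).to('cuda')
    image = pipe('a photo of an astronaut riding a horse on mars').images[0]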
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Dict, List, Optional, Tuple
+import os
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -119,7 +119,7 @@ class JointTransformerBlock(nn.Module):
                 f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
             )
         if hasattr(F, "scaled_dot_product_attention"):
-            processor = JointAttnProcessor2_0()
+            self.processor = JointAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
@@ -133,7 +133,7 @@ class JointTransformerBlock(nn.Module):
             out_dim=dim,
             context_pre_only=context_pre_only,
             bias=True,
-            processor=processor,
+            processor=self.processor,
         )
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -169,6 +169,9 @@ class JointTransformerBlock(nn.Module):
             )
         # Attention.
+        sd3_use_xformers = os.getenv('SD3_USE_XFORMERS', '0')
+        if sd3_use_xformers == '1':
+            self.attn.set_processor(self.processor)
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )
......
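
Note: both the block above and `JointAttnProcessor2_0` below re-read `SD3_USE_XFORMERS` on every forward pass, so (under that assumption) the backend can be toggled between calls without rebuilding the model. A small sketch continuing from the hypothetical `pipe` in the example above:

    import os

    # The flag is read per forward pass: flip it between calls to switch attention backends.
    os.environ['SD3_USE_XFORMERS'] = '1'   # xformers Triton flash-attention path
    image_fa = pipe('a corgi wearing sunglasses').images[0]

    os.environ['SD3_USE_XFORMERS'] = '0'   # default F.scaled_dot_product_attention path
    image_sdpa = pipe('a corgi wearing sunglasses').images[0]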
@@ -1076,15 +1076,27 @@ class JointAttnProcessor2_0:
         key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
         value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        sd3_use_xformers = os.getenv('SD3_USE_XFORMERS', '0')
+        if sd3_use_xformers == '1':
+            query = attn.head_to_batch_dim(query).contiguous()
+            key = attn.head_to_batch_dim(key).contiguous()
+            value = attn.head_to_batch_dim(value).contiguous()
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query, key, value, op=MemoryEfficientAttentionTritonFwdFlashBwOp
+            )
+            hidden_states = hidden_states.to(query.dtype)
+            hidden_states = attn.batch_to_head_dim(hidden_states)
+        else:
+            inner_dim = key.shape[-1]
+            head_dim = inner_dim // attn.heads
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            hidden_states = hidden_states.to(query.dtype)
         # Split the attention outputs.
         hidden_states, encoder_hidden_states = (
......
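
Note: the new branch uses `attn.head_to_batch_dim` to get the 3-D `(batch * heads, seq, head_dim)` layout that `xformers.ops.memory_efficient_attention` accepts, while the existing branch keeps the 4-D `(batch, heads, seq, head_dim)` layout for `F.scaled_dot_product_attention`. A standalone, torch-only sketch (shapes are illustrative assumptions; SDPA stands in for the xformers kernel) checking that the two layouts produce the same attention output:

    import torch
    import torch.nn.functional as F

    batch, heads, seq, head_dim = 2, 4, 8, 16
    q = torch.randn(batch, seq, heads * head_dim)
    k = torch.randn(batch, seq, heads * head_dim)
    v = torch.randn(batch, seq, heads * head_dim)

    # Existing branch: 4-D (batch, heads, seq, head_dim) layout for SDPA.
    def sdpa_4d(q, k, v):
        qh, kh, vh = (t.view(batch, -1, heads, head_dim).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(qh, kh, vh, dropout_p=0.0, is_causal=False)
        return out.transpose(1, 2).reshape(batch, -1, heads * head_dim)

    # New branch layout: 3-D (batch * heads, seq, head_dim), the shape attn.head_to_batch_dim
    # produces before the xformers call.
    def sdpa_3d(q, k, v):
        qb, kb, vb = (
            t.view(batch, -1, heads, head_dim).transpose(1, 2).reshape(batch * heads, -1, head_dim)
            for t in (q, k, v)
        )
        out = F.scaled_dot_product_attention(qb, kb, vb, dropout_p=0.0, is_causal=False)
        # Equivalent of attn.batch_to_head_dim: back to (batch, seq, heads * head_dim).
        return out.reshape(batch, heads, -1, head_dim).transpose(1, 2).reshape(batch, -1, heads * head_dim)

    print(torch.allclose(sdpa_4d(q, k, v), sdpa_3d(q, k, v), atol=1e-5))  # expected: True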
@@ -14,7 +14,7 @@
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
+import os
 import numpy as np
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
......
@@ -917,6 +917,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                 if XLA_AVAILABLE:
                     xm.mark_step()
+        sd3_use_xformers = os.getenv('SD3_USE_XFORMERS', '0')
+        if sd3_use_xformers == '1':
+            self.disable_xformers_memory_efficient_attention()
         if output_type == "latent":
             image = latents
......
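
Note: when the flag is set, the pipeline also disables xformers memory-efficient attention once the denoising loop finishes. To see which attention processors are installed on the transformer at any point (for example before and after a generation call), diffusers' `attn_processors` map can be inspected; `pipe` is the hypothetical pipeline from the first sketch:

    # List the attention processor classes currently installed on the SD3 transformer.
    procs = pipe.transformer.attn_processors  # dict: module path -> processor instance
    print(sorted({type(p).__name__ for p in procs.values()}))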