Commit a3528339 authored by lijian6

Add SD3 Triton flash attention (via xformers).


Signed-off-by: lijian <lijian6@sugon.com>
parent 8e4e71c8
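
This commit gates an xformers/Triton flash-attention path for Stable Diffusion 3 behind an SD3_USE_XFORMERS environment variable. A minimal usage sketch, assuming this patched diffusers build, an xformers install that ships the Triton fMHA op, and a supported GPU; the checkpoint id and generation parameters below are illustrative and not part of the commit:

import os
import torch
from diffusers import StableDiffusion3Pipeline

# Opt in to the Triton flash-attention branch added by this commit.
os.environ["SD3_USE_XFORMERS"] = "1"

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3_triton_fa.png")

The flag is read at attention-call time, so it only needs to be set before the pipeline is invoked.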
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Dict, List, Optional, Tuple
+import os
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -119,7 +119,7 @@ class JointTransformerBlock(nn.Module):
                 f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
             )
         if hasattr(F, "scaled_dot_product_attention"):
-            processor = JointAttnProcessor2_0()
+            self.processor = JointAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
@@ -133,7 +133,7 @@ class JointTransformerBlock(nn.Module):
             out_dim=dim,
             context_pre_only=context_pre_only,
             bias=True,
-            processor=processor,
+            processor=self.processor,
         )
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -169,6 +169,9 @@ class JointTransformerBlock(nn.Module):
         )
         # Attention.
+        sd3_use_xformers = os.getenv('SD3_USE_XFORMERS', '0')
+        if sd3_use_xformers == '1':
+            self.attn.set_processor(self.processor)
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )
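
When SD3_USE_XFORMERS=1, the forward pass above re-installs the stored JointAttnProcessor2_0 through Attention.set_processor. A small sketch of that mechanism on a plain diffusers Attention layer; the toy sizes and the default AttnProcessor2_0 are illustrative, not the joint-attention setup built in this block:

import torch
from diffusers.models.attention_processor import Attention, AttnProcessor2_0

attn = Attention(query_dim=64, heads=4, dim_head=16)  # toy self-attention layer
attn.set_processor(AttnProcessor2_0())  # same call the patched forward pass makes
out = attn(torch.randn(1, 8, 64))       # the installed processor computes attention
print(out.shape)                        # torch.Size([1, 8, 64])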
@@ -1076,15 +1076,27 @@ class JointAttnProcessor2_0:
         key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
         value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        sd3_use_xformers = os.getenv('SD3_USE_XFORMERS', '0')
+        if sd3_use_xformers == '1':
+            query = attn.head_to_batch_dim(query).contiguous()
+            key = attn.head_to_batch_dim(key).contiguous()
+            value = attn.head_to_batch_dim(value).contiguous()
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query, key, value, op=MemoryEfficientAttentionTritonFwdFlashBwOp
+            )
+            hidden_states = hidden_states.to(query.dtype)
+            hidden_states = attn.batch_to_head_dim(hidden_states)
+        else:
+            inner_dim = key.shape[-1]
+            head_dim = inner_dim // attn.heads
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            hidden_states = hidden_states.to(query.dtype)

         # Split the attention outputs.
         hidden_states, encoder_hidden_states = (
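
The two branches above hand their tensors to kernels that expect different layouts: F.scaled_dot_product_attention takes [batch, heads, seq, head_dim], while attn.head_to_batch_dim folds the heads into the batch to produce the [batch * heads, seq, head_dim] tensors that xformers.ops.memory_efficient_attention accepts (and attn.batch_to_head_dim undoes the folding). A standalone sketch of the reshapes in plain PyTorch, with arbitrary sizes and SDPA standing in for the xformers kernel:

import torch
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 16, 4, 8
q = torch.randn(batch, seq, heads * head_dim)
k = torch.randn(batch, seq, heads * head_dim)
v = torch.randn(batch, seq, heads * head_dim)

# SDPA branch: [B, S, H*D] -> [B, H, S, D]
q4 = q.view(batch, seq, heads, head_dim).transpose(1, 2)
k4 = k.view(batch, seq, heads, head_dim).transpose(1, 2)
v4 = v.view(batch, seq, heads, head_dim).transpose(1, 2)
out_sdpa = F.scaled_dot_product_attention(q4, k4, v4)
out_sdpa = out_sdpa.transpose(1, 2).reshape(batch, seq, heads * head_dim)

# xformers branch layout (what head_to_batch_dim does): [B, S, H*D] -> [B*H, S, D]
def head_to_batch_dim(t):
    return t.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3).reshape(batch * heads, seq, head_dim)

out_3d = F.scaled_dot_product_attention(head_to_batch_dim(q), head_to_batch_dim(k), head_to_batch_dim(v))
# batch_to_head_dim: [B*H, S, D] -> [B, S, H*D]
out_back = out_3d.reshape(batch, heads, seq, head_dim).permute(0, 2, 1, 3).reshape(batch, seq, heads * head_dim)

assert torch.allclose(out_sdpa, out_back, atol=1e-5)  # both layouts yield the same attention output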
@@ -14,7 +14,7 @@
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
+import os
 import numpy as np
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
@@ -917,6 +917,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                 if XLA_AVAILABLE:
                     xm.mark_step()

+        sd3_use_xformers = os.getenv('SD3_USE_XFORMERS', '0')
+        if sd3_use_xformers == '1':
+            self.disable_xformers_memory_efficient_attention()

         if output_type == "latent":
             image = latents
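
For reference, the kernel the new processor branch dispatches to can be exercised directly. A sketch assuming an xformers build that still exports MemoryEfficientAttentionTritonFwdFlashBwOp (recent releases have dropped the Triton forward op) and a GPU that xformers supports; the shapes are arbitrary:

import torch
import xformers.ops as xops

# [batch * heads, seq_len, head_dim] -- the layout head_to_batch_dim produces.
q = torch.randn(2 * 24, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = xops.memory_efficient_attention(
    q, k, v, op=xops.MemoryEfficientAttentionTritonFwdFlashBwOp
)
print(out.shape)  # torch.Size([48, 4096, 64])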