Commit 8e4e71c8 authored by lijian6

Add xformers flash attention for Flux


Signed-off-by: lijian <lijian6@sugon.com>
parent 8a79d8ec
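The patch gates the new xformers path behind a `FLUX_USE_XFORMERS` environment variable (read with `os.getenv`, default `'0'`); when it is not set to `'1'`, the Flux attention processors keep the existing `F.scaled_dot_product_attention` path. A minimal usage sketch follows, assuming a diffusers build that contains this patch, xformers installed (the patch imports `xformers.ops` symbols at module level), and a CUDA device; the checkpoint id and sampler settings are illustrative, not part of the commit:

# Usage sketch (assumes this patched diffusers build, xformers, and a CUDA device;
# the model id and sampler settings below are illustrative).
import os

# The patch reads FLUX_USE_XFORMERS inside the attention processors at call time,
# so setting it in the same process before inference is enough.
os.environ["FLUX_USE_XFORMERS"] = "1"

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_xformers.png")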
@@ -14,7 +14,7 @@
 import inspect
 import math
 from typing import Callable, List, Optional, Tuple, Union
+import os
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -23,7 +23,7 @@ from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
+from xformers.ops import MemoryEfficientAttentionFlashAttentionOp, MemoryEfficientAttentionTritonFwdFlashBwOp
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -328,11 +328,12 @@ class Attention(nn.Module):
             else:
                 try:
                     # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
+                    #_ = xformers.ops.memory_efficient_attention(
+                    #    torch.randn((1, 2, 40), device="cuda"),
+                    #    torch.randn((1, 2, 40), device="cuda"),
+                    #    torch.randn((1, 2, 40), device="cuda"),
+                    #)
+                    pass
                 except Exception as e:
                     raise e
@@ -1732,37 +1733,52 @@ class FluxSingleAttnProcessor2_0:
         query = attn.to_q(hidden_states)
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-        # Apply RoPE if needed
-        if image_rotary_emb is not None:
-            # YiYi to-do: update uising apply_rotary_emb
-            # from ..embeddings import apply_rotary_emb
-            # query = apply_rotary_emb(query, image_rotary_emb)
-            # key = apply_rotary_emb(key, image_rotary_emb)
-            query, key = apply_rope(query, key, image_rotary_emb)
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        flux_use_xformers = os.getenv('FLUX_USE_XFORMERS', '0')
+        if flux_use_xformers == '1':
+            q_seq_len = query.shape[1]
+            k_seq_len = key.shape[1]
+            inner_dim = key.shape[-1]
+            head_dim = inner_dim // attn.heads
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            if attn.norm_q is not None:
+                query = attn.norm_q(query)
+            if attn.norm_k is not None:
+                key = attn.norm_k(key)
+            if image_rotary_emb is not None:
+                query, key = apply_rope(query, key, image_rotary_emb)
+            query = query.transpose(1, 2).contiguous().view(batch_size, q_seq_len, -1)
+            key = key.transpose(1, 2).contiguous().view(batch_size, k_seq_len, -1)
+            query = attn.head_to_batch_dim(query).contiguous()
+            key = attn.head_to_batch_dim(key).contiguous()
+            value = attn.head_to_batch_dim(value).contiguous()
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query, key, value, op=MemoryEfficientAttentionTritonFwdFlashBwOp
+            )
+            hidden_states = hidden_states.to(query.dtype)
+            hidden_states = attn.batch_to_head_dim(hidden_states)
+        else:
+            inner_dim = key.shape[-1]
+            head_dim = inner_dim // attn.heads
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            if attn.norm_q is not None:
+                query = attn.norm_q(query)
+            if attn.norm_k is not None:
+                key = attn.norm_k(key)
+            if image_rotary_emb is not None:
+                query, key = apply_rope(query, key, image_rotary_emb)
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            hidden_states = hidden_states.to(query.dtype)
         if input_ndim == 4:
             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
@@ -1801,53 +1817,103 @@ class FluxAttnProcessor2_0:
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-        # `context` projections.
-        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        if attn.norm_added_q is not None:
-            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-        if attn.norm_added_k is not None:
-            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-        # attention
-        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-        if image_rotary_emb is not None:
-            # YiYi to-do: update uising apply_rotary_emb
-            # from ..embeddings import apply_rotary_emb
-            # query = apply_rotary_emb(query, image_rotary_emb)
-            # key = apply_rotary_emb(key, image_rotary_emb)
-            query, key = apply_rope(query, key, image_rotary_emb)
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        flux_use_xformers = os.getenv('FLUX_USE_XFORMERS', '0')
+        if flux_use_xformers == '1':
+            inner_dim = key.shape[-1]
+            head_dim = inner_dim // attn.heads
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            if attn.norm_q is not None:
+                query = attn.norm_q(query)
+            if attn.norm_k is not None:
+                key = attn.norm_k(key)
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+            q_seq_len = query.shape[2]
+            k_seq_len = key.shape[2]
+            v_seq_len = value.shape[2]
+            if image_rotary_emb is not None:
+                query, key = apply_rope(query, key, image_rotary_emb)
+            query = query.transpose(1, 2).contiguous().view(batch_size, q_seq_len, -1)
+            key = key.transpose(1, 2).contiguous().view(batch_size, k_seq_len, -1)
+            value = value.transpose(1, 2).contiguous().view(batch_size, v_seq_len, -1)
+            query = attn.head_to_batch_dim(query).contiguous()
+            key = attn.head_to_batch_dim(key).contiguous()
+            value = attn.head_to_batch_dim(value).contiguous()
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query, key, value, op=MemoryEfficientAttentionTritonFwdFlashBwOp
+            )
+            hidden_states = hidden_states.to(query.dtype)
+            hidden_states = attn.batch_to_head_dim(hidden_states)
+        else:
+            inner_dim = key.shape[-1]
+            head_dim = inner_dim // attn.heads
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            if attn.norm_q is not None:
+                query = attn.norm_q(query)
+            if attn.norm_k is not None:
+                key = attn.norm_k(key)
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+            if image_rotary_emb is not None:
+                query, key = apply_rope(query, key, image_rotary_emb)
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            hidden_states = hidden_states.to(query.dtype)
         encoder_hidden_states, hidden_states = (
             hidden_states[:, : encoder_hidden_states.shape[1]],
...
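In both processors, the xformers branch folds the attention heads into the batch dimension with `attn.head_to_batch_dim` before calling `xformers.ops.memory_efficient_attention` (which accepts `[batch * heads, seq_len, head_dim]` inputs), then restores the layout with `attn.batch_to_head_dim`. Below is a small self-contained sketch of that reshape convention, checked against SDPA on CPU; the helper functions are illustrative re-implementations, not the diffusers methods themselves:

import torch
import torch.nn.functional as F

def head_to_batch_dim(x, heads):
    # [batch, seq, heads * head_dim] -> [batch * heads, seq, head_dim]
    b, s, inner = x.shape
    head_dim = inner // heads
    return x.reshape(b, s, heads, head_dim).permute(0, 2, 1, 3).reshape(b * heads, s, head_dim)

def batch_to_head_dim(x, heads):
    # [batch * heads, seq, head_dim] -> [batch, seq, heads * head_dim]
    bh, s, head_dim = x.shape
    b = bh // heads
    return x.reshape(b, heads, s, head_dim).permute(0, 2, 1, 3).reshape(b, s, heads * head_dim)

batch, heads, seq, head_dim = 2, 4, 16, 8
q = torch.randn(batch, seq, heads * head_dim)
k = torch.randn(batch, seq, heads * head_dim)
v = torch.randn(batch, seq, heads * head_dim)

# Reference: standard multi-head attention on [batch, heads, seq, head_dim]
def split_heads(x):
    return x.reshape(batch, seq, heads, head_dim).transpose(1, 2)

ref = F.scaled_dot_product_attention(split_heads(q), split_heads(k), split_heads(v))
ref = ref.transpose(1, 2).reshape(batch, seq, heads * head_dim)

# Folded layout, like the tensors the patch feeds to memory_efficient_attention
out = F.scaled_dot_product_attention(
    head_to_batch_dim(q, heads), head_to_batch_dim(k, heads), head_to_batch_dim(v, heads)
)
out = batch_to_head_dim(out, heads)

print(torch.allclose(ref, out, atol=1e-6))  # True: the two layouts compute the same attention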