Commit 8e4e71c8 authored by lijian6's avatar lijian6
Browse files

Add xformers fa for flux


Signed-off-by: lijian6's avatarlijian <lijian6@sugon.com>
parent 8a79d8ec
......@@ -14,7 +14,7 @@
import inspect
import math
from typing import Callable, List, Optional, Tuple, Union
import os
import torch
import torch.nn.functional as F
from torch import nn
......@@ -23,7 +23,7 @@ from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp, MemoryEfficientAttentionTritonFwdFlashBwOp
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -328,11 +328,12 @@ class Attention(nn.Module):
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
#_ = xformers.ops.memory_efficient_attention(
# torch.randn((1, 2, 40), device="cuda"),
# torch.randn((1, 2, 40), device="cuda"),
# torch.randn((1, 2, 40), device="cuda"),
#)
pass
except Exception as e:
raise e
......@@ -1732,33 +1733,48 @@ class FluxSingleAttnProcessor2_0:
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
flux_use_xformers = os.getenv('FLUX_USE_XFORMERS', '0')
if flux_use_xformers == '1':
q_seq_len = query.shape[1]
k_seq_len = key.shape[1]
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
if image_rotary_emb is not None:
query, key = apply_rope(query, key, image_rotary_emb)
query = query.transpose(1, 2).contiguous().view(batch_size, q_seq_len, -1)
key = key.transpose(1, 2).contiguous().view(batch_size, k_seq_len, -1)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, op=MemoryEfficientAttentionTritonFwdFlashBwOp
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
else:
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
......@@ -1801,6 +1817,60 @@ class FluxAttnProcessor2_0:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
flux_use_xformers = os.getenv('FLUX_USE_XFORMERS', '0')
if flux_use_xformers == '1':
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
q_seq_len = query.shape[2]
k_seq_len = key.shape[2]
v_seq_len = value.shape[2]
if image_rotary_emb is not None:
query, key = apply_rope(query, key, image_rotary_emb)
query = query.transpose(1, 2).contiguous().view(batch_size, q_seq_len, -1)
key = key.transpose(1, 2).contiguous().view(batch_size, k_seq_len, -1)
value = value.transpose(1, 2).contiguous().view(batch_size, v_seq_len, -1)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, op=MemoryEfficientAttentionTritonFwdFlashBwOp
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
else:
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
......@@ -1839,10 +1909,6 @@ class FluxAttnProcessor2_0:
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment