Commit 0a4e78fc authored by lijian6

Add flash attention for HunyuanDiT


Signed-off-by: lijian <lijian6@sugon.com>
parent 39aa3909
@@ -299,11 +299,12 @@ class Attention(nn.Module):
         else:
             try:
                 # Make sure we can run the memory efficient attention
-                _ = xformers.ops.memory_efficient_attention(
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                )
+                # _ = xformers.ops.memory_efficient_attention(
+                #     torch.randn((1, 2, 40), device="cuda"),
+                #     torch.randn((1, 2, 40), device="cuda"),
+                #     torch.randn((1, 2, 40), device="cuda"),
+                # )
+                pass
             except Exception as e:
                 raise e
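Note: this hunk disables the hard-coded smoke test that `Attention` runs when xformers attention is toggled on, which would otherwise allocate tensors on `"cuda"` unconditionally. For reference only, a device-parameterized version of the same probe might look like the sketch below; the function name and standalone form are assumptions for illustration, not part of the commit.

# Minimal sketch (not part of the commit): the availability probe from the
# removed lines, parameterized over device instead of hard-coding "cuda".
import torch
import xformers.ops

def can_run_memory_efficient_attention(device: str = "cuda") -> bool:
    try:
        q = torch.randn((1, 2, 40), device=device)
        _ = xformers.ops.memory_efficient_attention(q, q, q)
        return True
    except Exception:
        return False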
@@ -1333,9 +1334,11 @@ class XFormersAttnProcessor:
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
     ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1379,6 +1382,25 @@ class XFormersAttnProcessor:
             key = attn.to_k(encoder_hidden_states)
             value = attn.to_v(encoder_hidden_states)

+        ### Add for HunYuanDiT xformers ###
+        q_seq_len = query.shape[1]
+        k_seq_len = key.shape[1]
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        if image_rotary_emb is not None:
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            if attn.norm_q is not None:
+                query = attn.norm_q(query)
+            if attn.norm_k is not None:
+                key = attn.norm_k(key)
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+            query = query.transpose(1, 2).contiguous().view(batch_size, q_seq_len, -1)
+            key = key.transpose(1, 2).contiguous().view(batch_size, k_seq_len, -1)
+        ### End add ###
+
         query = attn.head_to_batch_dim(query).contiguous()
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()
...
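The new branch reshapes query/key to [batch, heads, seq_len, head_dim], applies QK-norm if present, rotates by the image rotary embedding via `apply_rotary_emb`, then flattens back so the existing xformers path is untouched. A rough sketch of the rotary math, assuming `image_rotary_emb` is a `(cos, sin)` pair and the default real-valued path in diffusers' `embeddings` module (layout details here are assumptions, not a verbatim copy):

# Rough sketch of what `apply_rotary_emb` computes for
# x: [batch, heads, seq_len, head_dim] and freqs_cis = (cos, sin).
import torch

def apply_rotary_emb_sketch(x: torch.Tensor, freqs_cis) -> torch.Tensor:
    cos, sin = freqs_cis                  # each assumed [seq_len, head_dim]
    cos = cos[None, None].to(x.device)    # broadcast to [1, 1, seq, dim]
    sin = sin[None, None].to(x.device)
    # Treat consecutive channel pairs as (real, imag) and rotate them.
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)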
@@ -34,7 +34,7 @@ from ...utils import (
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
+from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -814,6 +814,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
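`enable_xformers_memory_efficient_attention` is an existing `DiffusionPipeline` method; with `attention_op` set, it pins every attention module to xformers' flash-attention kernel for the duration of the denoising loop. Conceptually it amounts to swapping attention processors, roughly as in the sketch below (the checkpoint id is an assumption; fp16 is used because the flash-attention op rejects fp32 inputs):

# Sketch of the manual equivalent of the added enable call, assuming
# diffusers + xformers are installed and a supported CUDA GPU is present.
import torch
from diffusers import HunyuanDiTPipeline
from diffusers.models.attention_processor import XFormersAttnProcessor
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
).to("cuda")
# Pin every attention processor to the flash-attention op.
pipe.transformer.set_attn_processor(
    XFormersAttnProcessor(attention_op=MemoryEfficientAttentionFlashAttentionOp)
)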
@@ -877,6 +878,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
                 if XLA_AVAILABLE:
                     xm.mark_step()
+        self.disable_xformers_memory_efficient_attention()
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
...
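With this patch the flash-attention path is enabled just before the denoising loop and disabled again before VAE decoding, so caller code is unchanged. A hypothetical usage sketch (the checkpoint id and prompt are assumptions; fp16 is required by the flash-attention op):

# Hypothetical end-to-end usage of the patched pipeline.
import torch
from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
).to("cuda")
# The patched __call__ enables the flash-attention op for the denoising
# loop and disables it again before VAE decode.
image = pipe(prompt="a watercolor painting of a fox").images[0]
image.save("hunyuan_flash.png")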