Unverified Commit 4bbc51d9 authored by Sayak Paul, committed by GitHub

[Attention processor] Better warning message when shifting to `AttnProcessor2_0` (#3457)

* add: debugging to enabling memory efficient processing

* add: better warning message.
parent f7b4f51c
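
For context, here is a minimal sketch (not part of the commit) of the call path that reaches the rewritten warning: enabling xFormers attention on a pipeline while PyTorch 2.0 is installed. The checkpoint id is only an example, and the sketch assumes a CUDA machine with both xformers and PyTorch 2.0 available.

```python
import torch
from diffusers import DiffusionPipeline

# Example checkpoint; any diffusers pipeline with Attention modules behaves the same.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# With PyTorch 2.0 installed, this request propagates down to
# `Attention.set_use_memory_efficient_attention_xformers`, hits the
# `hasattr(F, "scaled_dot_product_attention") and self.scale_qk` branch in the
# diff below, and emits the expanded warning instead of the terse one-liner.
pipe.enable_xformers_memory_efficient_attention()
```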
@@ -191,7 +191,10 @@ class Attention(nn.Module):
             elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
                 warnings.warn(
                     "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
-                    "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0."
+                    "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
+                    "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
+                    "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
+                    "native efficient flash attention."
                 )
             else:
                 try:
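
The new message names `F.scaled_dot_product_attention` explicitly. As a standalone illustration (toy shapes chosen here, not taken from the commit), this is the PyTorch 2.0 primitive the warning refers to:

```python
import torch
import torch.nn.functional as F

# Toy shapes: (batch, heads, sequence length, head dim).
query = torch.randn(2, 8, 64, 40)
key = torch.randn(2, 8, 64, 40)
value = torch.randn(2, 8, 64, 40)

# Available from PyTorch 2.0; dispatches to flash / memory-efficient kernels
# when the inputs allow it, otherwise falls back to the math implementation.
out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([2, 8, 64, 40])
```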
@@ -213,6 +216,9 @@ class Attention(nn.Module):
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+                print(
+                    f"is_lora is set to {is_lora}, type: LoRAXFormersAttnProcessor: {isinstance(processor, LoRAXFormersAttnProcessor)}"
+                )
             elif is_custom_diffusion:
                 processor = CustomDiffusionXFormersAttnProcessor(
                     train_kv=self.processor.train_kv,
@@ -250,6 +256,7 @@ class Attention(nn.Module):
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+        print("Still defaulting to: AttnProcessor2_0 :O")
         processor = (
             AttnProcessor2_0()
             if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
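
To see which processor a freshly loaded model actually ends up with under torch 2.x, one can inspect `attn_processors`; a rough sketch (checkpoint id is only an example):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

# Example checkpoint, used only for illustration.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# `attn_processors` maps attention module names to their processor instances.
# Under torch 2.x with the default `scale_qk`, these are expected to be AttnProcessor2_0.
processor_types = {type(proc).__name__ for proc in unet.attn_processors.values()}
print(processor_types)
print(all(isinstance(proc, AttnProcessor2_0) for proc in unet.attn_processors.values()))
```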