Unverified Commit 8669e831 authored by Sayak Paul, committed by GitHub

[LoRA] feat: add lora attention processor for pt 2.0. (#3594)

* feat: add lora attention processor for pt 2.0.

* explicit context manager for SDPA.

* switch to flash attention

* make shapes compatible to work optimally with SDPA.

* fix: circular import problem.

* explicitly specify the flash attention kernel in sdpa

* fall back to efficient attention context manager.

* remove explicit dispatch.

* fix: removed processor.

* fix: remove optional from type annotation.

* feat: make changes regarding LoRAAttnProcessor2_0.

* remove confusing warning.

* formatting.

* relax tolerance for PT 2.0

* fix: loading message.

* remove unnecessary logging.

* add: entry to the docs.

* add: network_alpha argument.

* relax tolerance.
parent b45204ea
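The change is mechanical but repeated in several places: wherever a `LoRAAttnProcessor` used to be instantiated unconditionally, the code now prefers `LoRAAttnProcessor2_0` whenever PyTorch exposes `F.scaled_dot_product_attention` (i.e. PyTorch 2.0+). A minimal sketch of that dispatch, assuming a diffusers version that includes this commit; the `hidden_size`/`cross_attention_dim` values are illustrative placeholders, not taken from the diff:

```python
# Sketch of the dispatch pattern this commit applies in the training script,
# the loaders, and the attention module. Sizes below are placeholders.
import torch.nn.functional as F

from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

# Prefer the PyTorch 2.0 processor when scaled_dot_product_attention is available.
lora_attn_processor_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
processor = lora_attn_processor_class(hidden_size=320, cross_attention_dim=768, rank=4)
```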
@@ -11,6 +11,9 @@ An attention processor is a class for applying different types of attention mech
 ## LoRAAttnProcessor
 [[autodoc]] models.attention_processor.LoRAAttnProcessor
+## LoRAAttnProcessor2_0
+[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0
 ## CustomDiffusionAttnProcessor
 [[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
......
@@ -55,6 +55,7 @@ from diffusers.models.attention_processor import (
     AttnAddedKVProcessor2_0,
     LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
     SlicedAttnAddedKVProcessor,
 )
 from diffusers.optimization import get_scheduler
@@ -844,8 +845,9 @@ def main(args):
        if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
            lora_attn_processor_class = LoRAAttnAddedKVProcessor
        else:
-            lora_attn_processor_class = LoRAAttnProcessor
+            lora_attn_processor_class = (
+                LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+            )
        unet_lora_attn_procs[name] = lora_attn_processor_class(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
        )
......
@@ -18,6 +18,7 @@ from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
 import torch
+import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
 from .models.attention_processor import (
@@ -27,6 +28,7 @@ from .models.attention_processor import (
     CustomDiffusionXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     SlicedAttnAddedKVProcessor,
     XFormersAttnProcessor,
@@ -287,7 +289,9 @@ class UNet2DConditionLoadersMixin:
                if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
                    attn_processor_class = LoRAXFormersAttnProcessor
                else:
-                    attn_processor_class = LoRAAttnProcessor
+                    attn_processor_class = (
+                        LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+                    )
                attn_processors[key] = attn_processor_class(
                    hidden_size=hidden_size,
@@ -927,11 +931,11 @@ class LoraLoaderMixin:
        # Load the layers corresponding to text encoder and make necessary adjustments.
        text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
-        logger.info(f"Loading {self.text_encoder_name}.")
        text_encoder_lora_state_dict = {
            k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
        }
        if len(text_encoder_lora_state_dict) > 0:
+            logger.info(f"Loading {self.text_encoder_name}.")
            attn_procs_text_encoder = self._load_text_encoder_attn_procs(
                text_encoder_lora_state_dict, network_alpha=network_alpha
            )
......
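With the `UNet2DConditionLoadersMixin` and `LoraLoaderMixin` changes above, loading LoRA weights on a PyTorch 2.0 install should populate the UNet with `LoRAAttnProcessor2_0` instances, while older PyTorch keeps `LoRAAttnProcessor`. A hedged sketch of how that could be checked from user code; the checkpoint path is a placeholder, and which processors end up being LoRA variants depends on the weights being loaded:

```python
import torch.nn.functional as F

from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("path/to/lora")  # placeholder: local dir or Hub repo with LoRA weights

expected_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
# Processors set up from the LoRA state dict should be the version-appropriate class.
lora_procs = [p for p in pipe.unet.attn_processors.values() if "LoRA" in p.__class__.__name__]
assert all(isinstance(p, expected_class) for p in lora_procs)
```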
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import warnings
 from typing import Callable, Optional, Union
 import torch
@@ -166,7 +165,8 @@ class Attention(nn.Module):
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        is_lora = hasattr(self, "processor") and isinstance(
-            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor)
+            self.processor,
+            (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
        )
        is_custom_diffusion = hasattr(self, "processor") and isinstance(
            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
@@ -200,14 +200,6 @@ class Attention(nn.Module):
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
-            elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
-                warnings.warn(
-                    "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
-                    "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
-                    "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
-                    "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
-                    "native efficient flash attention."
-                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
@@ -220,6 +212,8 @@ class Attention(nn.Module):
                    raise e
            if is_lora:
+                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
+                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
                processor = LoRAXFormersAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
@@ -252,7 +246,10 @@ class Attention(nn.Module):
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_lora:
-                processor = LoRAAttnProcessor(
+                attn_processor_class = (
+                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+                )
+                processor = attn_processor_class(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
@@ -548,6 +545,8 @@ class LoRAAttnProcessor(nn.Module):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
    """
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
@@ -843,6 +842,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
    """
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
@@ -1162,6 +1162,9 @@ class LoRAXFormersAttnProcessor(nn.Module):
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
    """
    def __init__(
@@ -1236,6 +1239,97 @@ class LoRAXFormersAttnProcessor(nn.Module):
        return hidden_states

+class LoRAAttnProcessor2_0(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
+    attention.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+        super().__init__()
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        residual = hidden_states
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        inner_dim = hidden_states.shape[-1]
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+        hidden_states = hidden_states / attn.rescale_output_factor
+        return hidden_states
+
 class CustomDiffusionXFormersAttnProcessor(nn.Module):
    r"""
    Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
@@ -1520,6 +1614,7 @@ AttentionProcessor = Union[
    XFormersAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
+    LoRAAttnProcessor2_0,
    LoRAAttnAddedKVProcessor,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
......
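For context on the "make shapes compatible to work optimally with SDPA" bullet: `F.scaled_dot_product_attention` takes `(batch, heads, seq_len, head_dim)` tensors, so `LoRAAttnProcessor2_0.__call__` reshapes the fused base+LoRA projections before the call and merges the heads back afterwards, mirroring `AttnProcessor2_0`. A standalone sketch of just that shape handling, with illustrative sizes not taken from the commit:

```python
import torch
import torch.nn.functional as F  # requires PyTorch 2.0+

batch, heads, seq_len, head_dim = 2, 8, 64, 40  # illustrative sizes
inner_dim = heads * head_dim

# Projections as produced by attn.to_q/to_k/to_v plus the scaled LoRA deltas: (batch, seq_len, inner_dim)
query = torch.randn(batch, seq_len, inner_dim)
key = torch.randn(batch, seq_len, inner_dim)
value = torch.randn(batch, seq_len, inner_dim)

# Split heads: (batch, seq_len, inner_dim) -> (batch, heads, seq_len, head_dim)
q = query.view(batch, -1, heads, head_dim).transpose(1, 2)
k = key.view(batch, -1, heads, head_dim).transpose(1, 2)
v = value.view(batch, -1, heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Merge heads back: (batch, heads, seq_len, head_dim) -> (batch, seq_len, inner_dim)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
assert out.shape == (batch, seq_len, inner_dim)
```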
@@ -19,6 +19,7 @@ import unittest
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
@@ -28,6 +29,7 @@ from diffusers.models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
     LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
@@ -46,16 +48,24 @@ def create_unet_lora_layers(unet: nn.Module):
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        lora_attn_processor_class = (
+            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+        )
+        lora_attn_procs[name] = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )
    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
    return lora_attn_procs, unet_lora_layers

 def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
    text_lora_attn_procs = {}
+    lora_attn_processor_class = (
+        LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+    )
    for name, module in text_encoder.named_modules():
        if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-            text_lora_attn_procs[name] = LoRAAttnProcessor(
+            text_lora_attn_procs[name] = lora_attn_processor_class(
                hidden_size=module.out_proj.out_features, cross_attention_dim=None
            )
    return text_lora_attn_procs
@@ -368,7 +378,10 @@ class LoraLoaderMixinTests(unittest.TestCase):
        # check if lora attention processors are used
        for _, module in sd_pipe.unet.named_modules():
            if isinstance(module, Attention):
-                self.assertIsInstance(module.processor, LoRAAttnProcessor)
+                attn_proc_class = (
+                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+                )
+                self.assertIsInstance(module.processor, attn_proc_class)

    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
    def test_lora_unet_attn_processors_with_xformers(self):
......
@@ -261,7 +261,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase):
        with torch.no_grad():
            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
-        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - new_sample).abs().max() < 5e-4
        # LoRA and no LoRA should NOT be the same
        assert (sample - old_sample).abs().max() > 1e-4
@@ -295,7 +295,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase):
        with torch.no_grad():
            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
-        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - new_sample).abs().max() < 3e-4
        # LoRA and no LoRA should NOT be the same
        assert (sample - old_sample).abs().max() > 1e-4
......