Unverified Commit ea39cd7e authored by Will Berman, committed by GitHub

Attn added kv processor torch 2.0 block (#3023)

add AttnAddedKVProcessor2_0 block
parent 98c5e5da
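As context for the diff below, here is a minimal opt-in sketch (not part of this commit). It mirrors the processor-selection logic added to the UNet blocks in this change, and assumes `unet` is an already-loaded model whose attention layers use the added-KV path (e.g. an IF/UnCLIP-style UNet); only the processor classes and the `hasattr` check come from this commit.

```python
# Hedged sketch: `unet` is assumed to exist and is not defined here.
import torch.nn.functional as F

from diffusers.models.attention_processor import AttnAddedKVProcessor, AttnAddedKVProcessor2_0


def preferred_added_kv_processor():
    # Same check the updated blocks use: prefer the fused torch 2.0 kernel when
    # F.scaled_dot_product_attention is available, otherwise fall back.
    if hasattr(F, "scaled_dot_product_attention"):
        return AttnAddedKVProcessor2_0()
    return AttnAddedKVProcessor()


# unet.set_attn_processor(preferred_added_kv_processor())  # `unet` is a placeholder
```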
@@ -255,11 +255,15 @@ class Attention(nn.Module):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor

-    def head_to_batch_dim(self, tensor):
+    def head_to_batch_dim(self, tensor, out_dim=3):
         head_size = self.heads
         batch_size, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
         return tensor

     def get_attention_scores(self, query, key, attention_mask=None):
@@ -293,7 +297,7 @@ class Attention(nn.Module):

         return attention_probs

-    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
+    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
         if batch_size is None:
             deprecate(
                 "batch_size=None",
@@ -320,8 +324,13 @@ class Attention(nn.Module):
         else:
             attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

-        if attention_mask.shape[0] < batch_size * head_size:
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
         return attention_mask

     def norm_encoder_hidden_states(self, encoder_hidden_states):
@@ -499,6 +508,64 @@ class AttnAddedKVProcessor:

         return hidden_states

+
+class AttnAddedKVProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query, out_dim=4)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key, out_dim=4)
+            value = attn.head_to_batch_dim(value, out_dim=4)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
@@ -764,6 +831,7 @@ AttentionProcessor = Union[
     SlicedAttnProcessor,
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
 ]
@@ -15,10 +15,11 @@ from typing import Optional

 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn

 from .attention import AdaGroupNorm, AttentionBlock
-from .attention_processor import Attention, AttnAddedKVProcessor
+from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
 from .transformer_2d import Transformer2DModel
@@ -612,6 +613,10 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         attentions = []

         for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=in_channels,
@@ -624,7 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
             resnets.append(
@@ -1396,6 +1401,11 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     skip_time_act=skip_time_act,
                 )
             )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=out_channels,
@@ -1408,7 +1418,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
@@ -2399,6 +2409,11 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     skip_time_act=skip_time_act,
                 )
             )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=out_channels,
@@ -2411,7 +2426,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
......
@@ -8,7 +8,12 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
 from ...models.attention import Attention
-from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
+from ...models.attention_processor import (
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
+    AttnProcessor,
+)
 from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
@@ -1545,6 +1550,10 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         attentions = []

         for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=in_channels,
@@ -1557,7 +1566,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
             resnets.append(
......
@@ -421,7 +421,12 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
     def test_attention_slicing_forward_pass(self):
         test_max_difference = torch_device == "cpu"

-        self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)
+        # Check is relaxed because there is not a torch 2.0 sliced attention added kv processor
+        expected_max_diff = 1e-2
+
+        self._test_attention_slicing_forward_pass(
+            test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
+        )

     # Overriding PipelineTesterMixin::test_inference_batch_single_identical
     # because UnCLIP undeterminism requires a looser check.
......
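A side note on the new `out_dim=4` path (not part of the commit itself): with `out_dim=4`, `head_to_batch_dim` keeps the heads as a separate dimension and `prepare_attention_mask` returns a broadcastable 4D mask, which is the layout `F.scaled_dot_product_attention` expects. A minimal shape sketch with made-up sizes, requiring PyTorch 2.0:

```python
# Illustrative only: the sizes below are arbitrary and not taken from the diff.
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 4, 16, 20, 8

# out_dim=4 layout: (batch, heads, seq_len, head_dim) instead of (batch * heads, seq_len, head_dim).
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# prepare_attention_mask(..., out_dim=4) yields a (batch, heads, 1, key_len) additive mask;
# an all-zero mask means "attend everywhere".
attention_mask = torch.zeros(batch, heads, 1, kv_len)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 16, 8])
```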