Unverified Commit 64bf5d33 authored by Birch-san, committed by GitHub

Support for cross-attention bias / mask (#2634)



* Cross-attention masks

prefer qualified symbol, fix accidental Optional

prefer qualified symbol in AttentionProcessor

prefer qualified symbol in embeddings.py

qualified symbol in transformer_2d

qualify FloatTensor in unet_2d_blocks

move new transformer_2d params attention_mask, encoder_attention_mask to the end of the parameter list, which is assumed (e.g. by functions such as checkpoint()) to have a stable positional interface. regard return_dict as a special case, assumed to be injected separately from the positional params (e.g. by create_custom_forward()).

move new encoder_attention_mask param to end of CrossAttn block interfaces and Unet2DCondition interface, to maintain positional param interface.

regenerate modeling_text_unet.py

remove unused import

unet_2d_condition encoder_attention_mask docs
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

versatile_diffusion/modeling_text_unet.py encoder_attention_mask docs
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

transformer_2d encoder_attention_mask docs
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

unet_2d_blocks.py: add parameter name comments
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

revert description. bool-to-bias treatment happens in unet_2d_condition only.

comment parameter names

fix copies, style

* encoder_attention_mask for SimpleCrossAttnDownBlock2D, SimpleCrossAttnUpBlock2D

* encoder_attention_mask for UNetMidBlock2DSimpleCrossAttn

* support attention_mask, encoder_attention_mask in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D, KAttentionBlock. fix binding of attention_mask, cross_attention_kwargs params in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D checkpoint invocations.

* fix mistake made during merge conflict resolution

* regenerate versatile_diffusion

* pass time embedding into checkpointed attention invocation

* always assume encoder_attention_mask is a mask (i.e. not a bias).

* style, fix-copies

* add tests for cross-attention masks

* add test for padding of attention mask

* explain mask's query_tokens dim. fix explanation about broadcasting over channels; we actually broadcast over query tokens

* support both masks and biases in Transformer2DModel#forward. document behaviour

* fix-copies

* delete attention_mask docs on the basis I never tested self-attention masking myself. not comfortable explaining it, since I don't actually understand how a self-attn mask can work in its current form: the key length will be different in every ResBlock (we don't downsample the mask when we downsample the image).

* review feedback: the standard Unet blocks shouldn't pass temb to attn (only to resnet). remove from KCrossAttnDownBlock2D,KCrossAttnUpBlock2D#forward.

* remove encoder_attention_mask param from SimpleCrossAttn{Up,Down}Block2D,UNetMidBlock2DSimpleCrossAttn, and mask-choice in those blocks' #forward, on the basis that they only do one type of attention, so the consumer can pass whichever type of attention_mask is appropriate.

* put attention mask padding back to how it was (since the SD use-case it enabled wasn't important, and it breaks the original unclip use-case). disable the test which was added.

* fix-copies

* style

* fix-copies

* put encoder_attention_mask param back into Simple block forward interfaces, to ensure consistency of forward interface.

* restore passing of emb to KAttentionBlock#forward, on the basis that removal caused test failures. restore also the passing of emb to checkpointed calls to KAttentionBlock#forward.

* make simple unet2d blocks use encoder_attention_mask, but only when attention_mask is None. this should fix UnCLIP compatibility.

* fix copies
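
For illustration only (not part of this PR): a minimal sketch of how a caller exercises the new `encoder_attention_mask` argument on `UNet2DConditionModel#forward`. The tiny config values and tensor sizes below are assumptions chosen just to keep the example small.

```python
import torch
from diffusers import UNet2DConditionModel

# a deliberately tiny, illustrative config so the sketch runs quickly
unet = UNet2DConditionModel(
    sample_size=16,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
)

batch, tokens = 2, 77
sample = torch.randn(batch, 4, 16, 16)
timestep = torch.tensor([10] * batch)
encoder_hidden_states = torch.randn(batch, tokens, 32)

# boolean cross-attention mask over the conditioning tokens: True = keep, False = discard.
# internally it is converted into an additive bias (0 for keep, -10000 for discard).
encoder_attention_mask = torch.ones(batch, tokens, dtype=torch.bool)
encoder_attention_mask[:, -10:] = False  # e.g. hide padding tokens at the end of the prompt

with torch.no_grad():
    out = unet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
    ).sample
print(out.shape)  # torch.Size([2, 4, 16, 16])
```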
parent c4359d63
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Any, Dict, Optional

 import torch
 import torch.nn.functional as F
@@ -120,13 +120,13 @@ class BasicTransformerBlock(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        timestep=None,
-        cross_attention_kwargs=None,
-        class_labels=None,
+        hidden_states: torch.FloatTensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        class_labels: Optional[torch.LongTensor] = None,
     ):
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 1. Self-Attention
@@ -155,8 +155,6 @@ class BasicTransformerBlock(nn.Module):
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
-            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
-            # prepare attention mask here
             attn_output = self.attn2(
                 norm_hidden_states,
...
@@ -380,7 +380,13 @@ class Attention(nn.Module):
         if attention_mask is None:
             return attention_mask

-        if attention_mask.shape[-1] != target_length:
+        current_length: int = attention_mask.shape[-1]
+        if current_length > target_length:
+            # we *could* trim the mask with:
+            #   attention_mask = attention_mask[:,:target_length]
+            # but this is weird enough that it's more likely to be a mistake than a shortcut
+            raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
+        elif current_length < target_length:
             if attention_mask.device.type == "mps":
                 # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                 # Instead, we can manually construct the padding tensor.
@@ -388,6 +394,10 @@ class Attention(nn.Module):
                 padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                 attention_mask = torch.cat([attention_mask, padding], dim=2)
             else:
+                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+                #   we want to instead pad by (0, remaining_length), where remaining_length is:
+                #   remaining_length: int = target_length - current_length
+                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

         if out_dim == 3:
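
For illustration (not part of the diff): a standalone sketch that mirrors the length handling above on a plain tensor. `pad_or_reject_mask` is a hypothetical helper, not a diffusers API; it also shows the padding quirk that the TODO and the skipped test refer to.

```python
import torch
import torch.nn.functional as F


def pad_or_reject_mask(attention_mask: torch.Tensor, target_length: int) -> torch.Tensor:
    # mirrors the logic above: reject over-long masks, right-pad short ones with "discard" (0.0)
    current_length = attention_mask.shape[-1]
    if current_length > target_length:
        raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
    elif current_length < target_length:
        # note: like the code above, this pads by target_length (the unclip behaviour),
        # not by target_length - current_length (what a short stable-diffusion mask would need)
        attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
    return attention_mask


mask = torch.ones(2, 1, 5)                     # [batch, 1, key_tokens]
print(pad_or_reject_mask(mask, 8).shape)       # torch.Size([2, 1, 13]): 5 original + 8 padded tokens
```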
@@ -820,7 +830,13 @@ class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op

-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+    ):
         residual = hidden_states

         input_ndim = hidden_states.ndim
@@ -829,11 +845,20 @@ class XFormersAttnProcessor:
             batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

-        batch_size, sequence_length, _ = (
+        batch_size, key_tokens, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )

-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+        if attention_mask is not None:
+            # expand our mask's singleton query_tokens dimension:
+            #   [batch*heads,            1, key_tokens] ->
+            #   [batch*heads, query_tokens, key_tokens]
+            # so that it can be added as a bias onto the attention scores that xformers computes:
+            #   [batch*heads, query_tokens, key_tokens]
+            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+            _, query_tokens, _ = hidden_states.shape
+            attention_mask = attention_mask.expand(-1, query_tokens, -1)

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
...
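
A standalone sketch (illustrative, not from the diff) of why the singleton query_tokens dimension is expanded before the bias is handed to xformers: the bias has to line up element-for-element with the `[batch*heads, query_tokens, key_tokens]` score tensor. Shapes below are made up for the example.

```python
import torch

batch_heads, query_tokens, key_tokens = 4, 64, 77

# bias as produced by prepare_attention_mask: one row per batch*heads, singleton query dim
bias = torch.zeros(batch_heads, 1, key_tokens)
bias[:, :, -10:] = -10000.0  # discard the last 10 key tokens

# expand the singleton dim to the full attention-score shape (a view, no copy)
expanded = bias.expand(-1, query_tokens, -1)
print(expanded.shape)  # torch.Size([4, 64, 77])

# adding it to scores of shape [batch*heads, query_tokens, key_tokens] now lines up exactly
scores = torch.randn(batch_heads, query_tokens, key_tokens)
masked_scores = scores + expanded
```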
@@ -352,7 +352,7 @@ class LabelEmbedding(nn.Module):
             labels = torch.where(drop_ids, self.num_classes, labels)
         return labels

-    def forward(self, labels, force_drop_ids=None):
+    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
         use_dropout = self.dropout_prob > 0
         if (self.training and use_dropout) or (force_drop_ids is not None):
             labels = self.token_drop(labels, force_drop_ids)
...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Dict, Optional

 import torch
 import torch.nn.functional as F
@@ -213,11 +213,13 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
     def forward(
         self,
-        hidden_states,
-        encoder_hidden_states=None,
-        timestep=None,
-        class_labels=None,
-        cross_attention_kwargs=None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
         """
@@ -228,11 +230,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
-            timestep ( `torch.long`, *optional*):
+            timestep ( `torch.LongTensor`, *optional*):
                 Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
             class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                 Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
                 conditioning.
+            encoder_attention_mask ( `torch.Tensor`, *optional*):
+                Cross-attention mask, applied to encoder_hidden_states. Two formats are supported:
+                    Mask `(batch, sequence_length)`: True = keep, False = discard.
+                    Bias `(batch, 1, sequence_length)`: 0 = keep, -10000 = discard.
+                If ndim == 2, it will be interpreted as a mask and converted into a bias consistent with the format
+                above. This bias will be added to the cross-attention scores.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
@@ -241,6 +249,29 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is the sample tensor.
         """
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch,                    1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0,     discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
         # 1. Input
         if self.is_input_continuous:
             batch, _, height, width = hidden_states.shape
@@ -264,7 +295,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         for block in self.transformer_blocks:
             hidden_states = block(
                 hidden_states,
+                attention_mask=attention_mask,
                 encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
                 timestep=timestep,
                 cross_attention_kwargs=cross_attention_kwargs,
                 class_labels=class_labels,
...
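
To make the mask-vs-bias convention above concrete, an illustrative sketch (not from the diff) of the ndim == 2 conversion and its effect on the softmax; `mask_to_bias` is a hypothetical helper written only for this example.

```python
import torch


def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # (1 = keep, 0 = discard)  ->  (keep = +0.0, discard = -10000.0), plus a singleton query_tokens dim
    bias = (1 - mask.to(dtype)) * -10000.0
    return bias.unsqueeze(1)  # [batch, key_tokens] -> [batch, 1, key_tokens]


mask = torch.tensor([[1, 1, 1, 0, 0]])  # keep 3 tokens, discard 2
bias = mask_to_bias(mask)               # [[[0., 0., 0., -10000., -10000.]]]

scores = torch.zeros(1, 1, 5)           # pretend attention scores for a single query token
probs = (scores + bias).softmax(dim=-1)
print(probs)                            # ~[[[0.3333, 0.3333, 0.3333, 0.0, 0.0]]]
```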
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Any, Dict, Optional, Tuple

 import numpy as np
 import torch
@@ -558,14 +558,22 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         self.resnets = nn.ModuleList(resnets)

     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
-    ):
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             hidden_states = attn(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
                 return_dict=False,
             )[0]
             hidden_states = resnet(hidden_states, temb)
@@ -659,16 +667,34 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         self.resnets = nn.ModuleList(resnets)

     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

+        if attention_mask is None:
+            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+            mask = None if encoder_hidden_states is None else encoder_attention_mask
+        else:
+            # when attention_mask is defined: we don't even check for encoder_attention_mask.
+            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+            #       then we can simplify this whole if/else block to:
+            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+            mask = attention_mask
+
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             # attn
             hidden_states = attn(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=mask,
                 **cross_attention_kwargs,
             )
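
The same selection logic, restated as a hypothetical helper (`choose_mask` is not part of the diff) purely to make the precedence explicit: a caller-supplied `attention_mask` wins (the UnCLIP compatibility path), otherwise the cross-attention mask is used only when we are actually cross-attending.

```python
from typing import Optional

import torch


def choose_mask(
    attention_mask: Optional[torch.Tensor],
    encoder_attention_mask: Optional[torch.Tensor],
    encoder_hidden_states: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Hypothetical helper mirroring the mask choice made in the Simple blocks above."""
    if attention_mask is not None:
        # UnCLIP compatibility: attention_mask is honoured even for cross-attention
        return attention_mask
    # cross-attending (encoder_hidden_states given) -> use the cross-attn mask; else no mask
    return None if encoder_hidden_states is None else encoder_attention_mask
```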
@@ -850,9 +876,14 @@ class CrossAttnDownBlock2D(nn.Module):
         self.gradient_checkpointing = False

     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
-        # TODO(Patrick, William) - attention mask is not used
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
@@ -867,26 +898,23 @@ class CrossAttnDownBlock2D(nn.Module):

                     return custom_forward

-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                        use_reentrant=False,
-                    )[0]
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                    )[0]
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    None,  # timestep
+                    None,  # class_labels
+                    cross_attention_kwargs,
+                    attention_mask,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -894,6 +922,8 @@ class CrossAttnDownBlock2D(nn.Module):
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
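
The checkpointing refactor above depends on `Transformer2DModel#forward` keeping a stable positional parameter order (see the commit message): `torch.utils.checkpoint` forwards arguments positionally, so `None` placeholders hold the slots for `timestep` and `class_labels`, while `return_dict` is bound through the `create_custom_forward` closure. A generic sketch of that pattern follows; `attn_like` is a made-up stand-in, not a diffusers function.

```python
import torch
from torch.utils.checkpoint import checkpoint


def attn_like(hidden_states, encoder_hidden_states=None, timestep=None, class_labels=None,
              cross_attention_kwargs=None, attention_mask=None, encoder_attention_mask=None,
              return_dict=True):
    # stand-in with the same positional order as the real forward; checkpoint() can only pass values positionally
    out = hidden_states + (0.0 if attention_mask is None else attention_mask)
    return {"sample": out} if return_dict else (out,)


def create_custom_forward(module, return_dict=None):
    # return_dict is injected via the closure, not via the positional args
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


x = torch.randn(2, 8, 16, requires_grad=True)
bias = torch.zeros(2, 1, 16)

# mirrors the call above: None placeholders keep attention_mask at its expected position
out = checkpoint(
    create_custom_forward(attn_like, return_dict=False),
    x,
    None,   # encoder_hidden_states
    None,   # timestep
    None,   # class_labels
    None,   # cross_attention_kwargs
    bias,   # attention_mask
    None,   # encoder_attention_mask
)[0]
out.sum().backward()
```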
@@ -1501,11 +1531,28 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         self.gradient_checkpointing = False

     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         output_states = ()
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

+        if attention_mask is None:
+            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+            mask = None if encoder_hidden_states is None else encoder_attention_mask
+        else:
+            # when attention_mask is defined: we don't even check for encoder_attention_mask.
+            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+            #       then we can simplify this whole if/else block to:
+            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+            mask = attention_mask
+
         for resnet, attn in zip(self.resnets, self.attentions):
             if self.training and self.gradient_checkpointing:
@@ -1523,6 +1570,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     create_custom_forward(attn, return_dict=False),
                     hidden_states,
                     encoder_hidden_states,
+                    mask,
                     cross_attention_kwargs,
                 )[0]
             else:
@@ -1531,7 +1579,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
-                    attention_mask=attention_mask,
+                    attention_mask=mask,
                     **cross_attention_kwargs,
                 )
@@ -1690,7 +1738,13 @@ class KCrossAttnDownBlock2D(nn.Module):
         self.gradient_checkpointing = False

     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         output_states = ()
@@ -1706,28 +1760,22 @@ class KCrossAttnDownBlock2D(nn.Module):

                     return custom_forward

-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        attention_mask,
-                        cross_attention_kwargs,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        attention_mask,
-                        cross_attention_kwargs,
-                    )
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    cross_attention_kwargs,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1737,6 +1785,7 @@ class KCrossAttnDownBlock2D(nn.Module):
                     emb=temb,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
                 )

         if self.downsamplers is None:
@@ -1916,15 +1965,15 @@ class CrossAttnUpBlock2D(nn.Module):
     def forward(
         self,
-        hidden_states,
-        res_hidden_states_tuple,
-        temb=None,
-        encoder_hidden_states=None,
-        cross_attention_kwargs=None,
-        upsample_size=None,
-        attention_mask=None,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
-        # TODO(Patrick, William) - attention mask is not used
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -1942,26 +1991,23 @@ class CrossAttnUpBlock2D(nn.Module):

                     return custom_forward

-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                        use_reentrant=False,
-                    )[0]
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                    )[0]
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    None,  # timestep
+                    None,  # class_labels
+                    cross_attention_kwargs,
+                    attention_mask,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1969,6 +2015,8 @@ class CrossAttnUpBlock2D(nn.Module):
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
@@ -2594,15 +2642,28 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
     def forward(
         self,
-        hidden_states,
-        res_hidden_states_tuple,
-        temb=None,
-        encoder_hidden_states=None,
-        upsample_size=None,
-        attention_mask=None,
-        cross_attention_kwargs=None,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

+        if attention_mask is None:
+            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+            mask = None if encoder_hidden_states is None else encoder_attention_mask
+        else:
+            # when attention_mask is defined: we don't even check for encoder_attention_mask.
+            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+            #       then we can simplify this whole if/else block to:
+            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+            mask = attention_mask
+
         for resnet, attn in zip(self.resnets, self.attentions):
             # resnet
             # pop res hidden states
@@ -2626,6 +2687,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     create_custom_forward(attn, return_dict=False),
                     hidden_states,
                     encoder_hidden_states,
+                    mask,
                     cross_attention_kwargs,
                 )[0]
             else:
@@ -2634,7 +2696,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
-                    attention_mask=attention_mask,
+                    attention_mask=mask,
                     **cross_attention_kwargs,
                 )
@@ -2811,13 +2873,14 @@ class KCrossAttnUpBlock2D(nn.Module):
     def forward(
         self,
-        hidden_states,
-        res_hidden_states_tuple,
-        temb=None,
-        encoder_hidden_states=None,
-        cross_attention_kwargs=None,
-        upsample_size=None,
-        attention_mask=None,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         res_hidden_states_tuple = res_hidden_states_tuple[-1]
         if res_hidden_states_tuple is not None:
@@ -2835,28 +2898,22 @@ class KCrossAttnUpBlock2D(nn.Module):

                     return custom_forward

-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        attention_mask,
-                        cross_attention_kwargs,
-                        use_reentrant=False,
-                    )[0]
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        attention_mask,
-                        cross_attention_kwargs,
-                    )[0]
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    cross_attention_kwargs,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2866,6 +2923,7 @@ class KCrossAttnUpBlock2D(nn.Module):
                     emb=temb,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
                 )

         if self.upsamplers is not None:
@@ -2944,11 +3002,14 @@ class KAttentionBlock(nn.Module):
     def forward(
         self,
-        hidden_states,
-        encoder_hidden_states=None,
-        emb=None,
-        attention_mask=None,
-        cross_attention_kwargs=None,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        # TODO: mark emb as non-optional (self.norm2 requires it).
+        #       requires assessing impact of change to positional param interface.
+        emb: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
@@ -2962,6 +3023,7 @@ class KAttentionBlock(nn.Module):
             attn_output = self.attn1(
                 norm_hidden_states,
                 encoder_hidden_states=None,
+                attention_mask=attention_mask,
                 **cross_attention_kwargs,
             )
             attn_output = self._to_4d(attn_output, height, weight)
@@ -2976,6 +3038,7 @@ class KAttentionBlock(nn.Module):
         attn_output = self.attn2(
             norm_hidden_states,
             encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,
             **cross_attention_kwargs,
         )
         attn_output = self._to_4d(attn_output, height, weight)
...
@@ -618,6 +618,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
@@ -625,6 +626,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
             timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
             encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            encoder_attention_mask (`torch.Tensor`):
+                (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
+                discard. Mask will be converted into a bias, which adds large negative values to attention scores
+                corresponding to "discard" tokens.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
             cross_attention_kwargs (`dict`, *optional*):
@@ -651,11 +656,27 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             logger.info("Forward upsample size to force interpolation output size.")
             forward_upsample_size = True

-        # prepare attention_mask
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch,                    1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
         if attention_mask is not None:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0,     discard = -10000.0)
             attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
             attention_mask = attention_mask.unsqueeze(1)

+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -727,6 +748,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                     encoder_hidden_states=encoder_hidden_states,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
@@ -752,6 +774,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 encoder_hidden_states=encoder_hidden_states,
                 attention_mask=attention_mask,
                 cross_attention_kwargs=cross_attention_kwargs,
+                encoder_attention_mask=encoder_attention_mask,
             )

         if mid_block_additional_residual is not None:
@@ -778,6 +801,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                     cross_attention_kwargs=cross_attention_kwargs,
                     upsample_size=upsample_size,
                     attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 sample = upsample_block(
...
@@ -721,6 +721,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
@@ -728,6 +729,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
             timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
             encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            encoder_attention_mask (`torch.Tensor`):
+                (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
+                discard. Mask will be converted into a bias, which adds large negative values to attention scores
+                corresponding to "discard" tokens.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
             cross_attention_kwargs (`dict`, *optional*):
@@ -754,11 +759,27 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             logger.info("Forward upsample size to force interpolation output size.")
             forward_upsample_size = True

-        # prepare attention_mask
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch,                    1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
         if attention_mask is not None:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0,     discard = -10000.0)
             attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
             attention_mask = attention_mask.unsqueeze(1)

+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -830,6 +851,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                     encoder_hidden_states=encoder_hidden_states,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
@@ -855,6 +877,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 encoder_hidden_states=encoder_hidden_states,
                 attention_mask=attention_mask,
                 cross_attention_kwargs=cross_attention_kwargs,
+                encoder_attention_mask=encoder_attention_mask,
             )

         if mid_block_additional_residual is not None:
@@ -881,6 +904,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                     cross_attention_kwargs=cross_attention_kwargs,
                     upsample_size=upsample_size,
                     attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 sample = upsample_block(
@@ -1188,9 +1212,14 @@ class CrossAttnDownBlockFlat(nn.Module):
         self.gradient_checkpointing = False

     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
-        # TODO(Patrick, William) - attention mask is not used
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
@@ -1205,26 +1234,23 @@ class CrossAttnDownBlockFlat(nn.Module):

                     return custom_forward

-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                        use_reentrant=False,
-                    )[0]
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                    )[0]
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    None,  # timestep
+                    None,  # class_labels
+                    cross_attention_kwargs,
+                    attention_mask,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1232,6 +1258,8 @@ class CrossAttnDownBlockFlat(nn.Module):
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
@@ -1414,15 +1442,15 @@ class CrossAttnUpBlockFlat(nn.Module):
     def forward(
         self,
-        hidden_states,
-        res_hidden_states_tuple,
-        temb=None,
-        encoder_hidden_states=None,
-        cross_attention_kwargs=None,
-        upsample_size=None,
-        attention_mask=None,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
-        # TODO(Patrick, William) - attention mask is not used
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -1440,26 +1468,23 @@ class CrossAttnUpBlockFlat(nn.Module):

                     return custom_forward

-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                        use_reentrant=False,
-                    )[0]
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                    )[0]
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    None,  # timestep
+                    None,  # class_labels
+                    cross_attention_kwargs,
+                    attention_mask,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1467,6 +1492,8 @@ class CrossAttnUpBlockFlat(nn.Module):
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
@@ -1564,14 +1591,22 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         self.resnets = nn.ModuleList(resnets)

     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
-    ):
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             hidden_states = attn(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
                 return_dict=False,
             )[0]
             hidden_states = resnet(hidden_states, temb)
@@ -1666,16 +1701,34 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         self.resnets = nn.ModuleList(resnets)

     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

+        if attention_mask is None:
+            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+            mask = None if encoder_hidden_states is None else encoder_attention_mask
+        else:
+            # when attention_mask is defined: we don't even check for encoder_attention_mask.
+            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+            #       then we can simplify this whole if/else block to:
+            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+            mask = attention_mask
+
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             # attn
             hidden_states = attn(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=mask,
                 **cross_attention_kwargs,
             )
...
@@ -20,6 +20,7 @@ import unittest
 import torch
 from parameterized import parameterized
+from pytest import mark

 from diffusers import UNet2DConditionModel
 from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, LoRAAttnProcessor
@@ -418,6 +419,76 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         assert processor.is_run
         assert processor.number == 123

+    @parameterized.expand(
+        [
+            # fmt: off
+            [torch.bool],
+            [torch.long],
+            [torch.float],
+            # fmt: on
+        ]
+    )
+    def test_model_xattn_mask(self, mask_dtype):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
+        model.to(torch_device)
+        model.eval()
+
+        cond = inputs_dict["encoder_hidden_states"]
+        with torch.no_grad():
+            full_cond_out = model(**inputs_dict).sample
+            assert full_cond_out is not None
+
+            keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype)
+            full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample
+            assert full_cond_keepallmask_out.allclose(
+                full_cond_out
+            ), "a 'keep all' mask should give the same result as no mask"
+
+            trunc_cond = cond[:, :-1, :]
+            trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample
+            assert not trunc_cond_out.allclose(
+                full_cond_out
+            ), "discarding the last token from our cond should change the result"
+
+            batch, tokens, _ = cond.shape
+            mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype)
+            masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample
+            assert masked_cond_out.allclose(
+                trunc_cond_out
+            ), "masking the last token from our cond should be equivalent to truncating that token out of the condition"
+
+    # see diffusers.models.attention_processor::Attention#prepare_attention_mask
+    # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
+    #       since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric.
+    #       maybe it's fine that this only works for the unclip use-case.
+    @mark.skip(
+        reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
+    )
+    def test_model_xattn_padding(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
+        model.to(torch_device)
+        model.eval()
+
+        cond = inputs_dict["encoder_hidden_states"]
+        with torch.no_grad():
+            full_cond_out = model(**inputs_dict).sample
+            assert full_cond_out is not None
+
+            batch, tokens, _ = cond.shape
+            keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
+            keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
+            assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
+
+            trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
+            trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
+            assert trunc_mask_out.allclose(
+                keeplast_out
+            ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
+
     def test_lora_processors(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
...