    Support for cross-attention bias / mask (#2634) · 64bf5d33
    Birch-san authored
    
    
    * Cross-attention masks
    
    prefer qualified symbol, fix accidental Optional
    
    prefer qualified symbol in AttentionProcessor
    
    prefer qualified symbol in embeddings.py
    
    qualified symbol in transformer_2d
    
    qualify FloatTensor in unet_2d_blocks
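
    (illustrative only, assuming the usual torch import — "qualified symbol" here means referring to FloatTensor through the torch namespace in annotations rather than importing the bare name; the function itself is a made-up example:)

        import torch

        # refer to FloatTensor via the torch namespace (torch.FloatTensor)
        # rather than importing the unqualified name directly
        def scale(hidden_states: torch.FloatTensor, factor: float) -> torch.FloatTensor:
            return hidden_states * factor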
    
    move new transformer_2d params attention_mask, encoder_attention_mask to the end of the signature, which is assumed (e.g. by functions such as checkpoint()) to have a stable positional param interface. treat return_dict as a special case, assumed to be injected separately from the positional params (e.g. by create_custom_forward()).
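
    (a minimal sketch of the pattern referred to here, assuming torch; the helper name create_custom_forward matches the message above, but the block signature shown in the commented call is illustrative, not the exact interface:)

        import torch.utils.checkpoint


        def create_custom_forward(module, return_dict=None):
            # return_dict is injected here as a keyword, never positionally,
            # so checkpoint() only ever sees the stable positional interface.
            def custom_forward(*inputs):
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                return module(*inputs)

            return custom_forward


        # hypothetical call site: new params are appended after the existing ones,
        # so older positional invocations keep binding to the same slots.
        # hidden_states = torch.utils.checkpoint.checkpoint(
        #     create_custom_forward(block, return_dict=False),
        #     hidden_states,              # hidden_states
        #     encoder_hidden_states,      # encoder_hidden_states
        #     attention_mask,             # attention_mask
        #     encoder_attention_mask,     # encoder_attention_mask
        # )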
    
    move new encoder_attention_mask param to end of CrossAttn block interfaces and Unet2DCondition interface, to maintain positional param interface.
    
    regenerate modeling_text_unet.py
    
    remove unused import
    
    unet_2d_condition encoder_attention_mask docs
    Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
    
    versatile_diffusion/modeling_text_unet.py encoder_attention_mask docs
    Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
    
    transformer_2d encoder_attention_mask docs
    Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
    
    unet_2d_blocks.py: add parameter name comments
    Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
    
    revert description. bool-to-bias treatment happens in unet_2d_condition only.
    
    comment parameter names
    
    fix copies, style
    
    * encoder_attention_mask for SimpleCrossAttnDownBlock2D, SimpleCrossAttnUpBlock2D
    
    * encoder_attention_mask for UNetMidBlock2DSimpleCrossAttn
    
    * support attention_mask, encoder_attention_mask in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D, KAttentionBlock. fix binding of attention_mask, cross_attention_kwargs params in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D checkpoint invocations.
    
    * fix mistake made during merge conflict resolution
    
    * regenerate versatile_diffusion
    
    * pass time embedding into checkpointed attention invocation
    
    * always assume encoder_attention_mask is a mask (i.e. not a bias).
    
    * style, fix-copies
    
    * add tests for cross-attention masks
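
    (the actual tests live in the diffusers test suite; purely as a hypothetical illustration of the kind of property such a test can check — an all-ones boolean mask should be a no-op — something along these lines, with a deliberately tiny, made-up model config:)

        import torch
        from diffusers import UNet2DConditionModel

        # deliberately tiny config for illustration; only the mask behaviour matters
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=1,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            attention_head_dim=8,
        )
        sample = torch.randn(1, 4, 32, 32)
        encoder_hidden_states = torch.randn(1, 77, 32)
        keep_all = torch.ones(1, 77, dtype=torch.bool)  # True = attend to that token

        with torch.no_grad():
            masked = unet(sample, 1, encoder_hidden_states, encoder_attention_mask=keep_all).sample
            unmasked = unet(sample, 1, encoder_hidden_states).sample

        # an all-ones mask converts to an all-zero bias, so outputs should match
        assert torch.allclose(masked, unmasked, atol=1e-4)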
    
    * add test for padding of attention mask
    
    * explain mask's query_tokens dim. fix explanation about broadcasting over channels; we actually broadcast over query tokens
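
    (in concrete terms, assuming torch: a bias of shape [batch, 1, key_tokens] has a singleton query_tokens dim, so it broadcasts over query tokens — not channels — when added to the attention scores; shapes below are hypothetical:)

        import torch

        batch, query_tokens, key_tokens, dim = 2, 4096, 77, 32
        query = torch.randn(batch, query_tokens, dim)
        key = torch.randn(batch, key_tokens, dim)
        bias = torch.zeros(batch, 1, key_tokens)  # singleton query_tokens dim
        bias[:, :, 40:] = -10000.0                # hypothetically mask out padding keys

        scores = query @ key.transpose(-1, -2) * dim ** -0.5  # [batch, query_tokens, key_tokens]
        scores = scores + bias                                # broadcasts over query_tokens
        probs = scores.softmax(dim=-1)
        assert probs.shape == (batch, query_tokens, key_tokens)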
    
    * support both masks and biases in Transformer2DModel#forward. document behaviour
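
    (a sketch of the documented behaviour, assuming torch; the ndim check reflects my reading of it — a 2-dim keep/discard mask gets converted to a bias, while an already-expanded bias passes through unchanged — not the verbatim forward code:)

        import torch


        def preprocess_encoder_attention_mask(encoder_attention_mask, dtype):
            # mask format: [batch, key_tokens], True / 1 = keep, False / 0 = discard
            # bias format: [batch, 1, key_tokens], 0.0 = keep, -10000.0 = discard
            if encoder_attention_mask is None:
                return None
            if encoder_attention_mask.ndim == 2:
                encoder_attention_mask = (1 - encoder_attention_mask.to(dtype)) * -10000.0
                encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
            return encoder_attention_mask  # added to the cross-attention scores downstream


        mask = torch.tensor([[True, True, False]])
        bias = preprocess_encoder_attention_mask(mask, torch.float32)
        # bias: 0.0 for kept tokens, -10000.0 for the discarded one, shape [1, 1, 3]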
    
    * fix-copies
    
    * delete attention_mask docs, on the basis that I never tested self-attention masking myself. not comfortable explaining it, since I don't actually understand how a self-attn mask can work in its current form: the key length differs in every ResBlock (we don't downsample the mask when we downsample the image).
    
    * review feedback: the standard Unet blocks shouldn't pass temb to attn (only to resnet). remove from KCrossAttnDownBlock2D, KCrossAttnUpBlock2D#forward.
    
    * remove encoder_attention_mask param from SimpleCrossAttn{Up,Down}Block2D,UNetMidBlock2DSimpleCrossAttn, and mask-choice in those blocks' #forward, on the basis that they only do one type of attention, so the consumer can pass whichever type of attention_mask is appropriate.
    
    * put attention mask padding back to how it was (since the SD use-case it enabled wasn't important, and it breaks the original unclip use-case). disable the test which was added.
    
    * fix-copies
    
    * style
    
    * fix-copies
    
    * put encoder_attention_mask param back into Simple block forward interfaces, to ensure consistency of forward interface.
    
    * restore passing of emb to KAttentionBlock#forward, on the basis that removal caused test failures. restore also the passing of emb to checkpointed calls to KAttentionBlock#forward.
    
    * make simple unet2d blocks use encoder_attention_mask, but only when attention_mask is None. this should fix UnCLIP compatibility.
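
    (roughly, as a sketch of the selection logic described — not the verbatim block code:)

        def select_mask(attention_mask, encoder_attention_mask, encoder_hidden_states):
            # UnCLIP passes its cross-attn mask via `attention_mask`, so that param wins
            # whenever it is supplied; otherwise fall back to encoder_attention_mask,
            # which only applies when we are actually cross-attending.
            if attention_mask is not None:
                return attention_mask
            return encoder_attention_mask if encoder_hidden_states is not None else None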
    
    * fix copies