    Attention processor cross attention norm group norm (#3021) · 98c5e5da
    Will Berman authored
    add group norm type to attention processor cross attention norm
    
    This lets the cross attention norm use either a group norm block or a
    layer norm block.
    
    The group norm operates along the channels dimension
    and requires input shape (batch size, channels, *), whereas the layer norm with a single
    `normalized_shape` dimension only operates over the least significant
    dimension, i.e. (*, channels).
    
    The channels we want to normalize are the hidden dimension of the encoder hidden states.
    
    By convention, the encoder hidden states are always passed as (batch size, sequence
    length, hidden dim).
    
    This means the layer norm can operate on the tensor without modification, but the group
    norm requires flipping the last two dimensions to operate on (batch size, hidden dim, sequence length).
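    
    As a minimal sketch (not from this commit; the shapes and module arguments are illustrative), the difference looks like this in PyTorch:
    
    ```python
    import torch
    import torch.nn as nn
    
    batch, seq_len, hidden = 2, 77, 768
    encoder_hidden_states = torch.randn(batch, seq_len, hidden)
    
    # Layer norm with a single normalized_shape dim normalizes over the last
    # dimension, so it can consume (batch, seq_len, hidden) directly.
    layer_norm = nn.LayerNorm(hidden)
    ln_out = layer_norm(encoder_hidden_states)  # (batch, seq_len, hidden)
    
    # Group norm normalizes over the channels dimension and expects it at dim 1,
    # i.e. (batch, channels, *), so flip the last two dims before and after.
    group_norm = nn.GroupNorm(num_groups=32, num_channels=hidden)
    gn_out = group_norm(encoder_hidden_states.transpose(1, 2))  # (batch, hidden, seq_len)
    gn_out = gn_out.transpose(1, 2)  # back to (batch, seq_len, hidden)
    ```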
    
    All existing attention processors share this logic, so we consolidate it
    in a helper function `prepare_encoder_hidden_states`.
    
    prepare_encoder_hidden_states -> norm_encoder_hidden_states re: @patrickvonplaten
    
    move norm_cross defined check to outside norm_encoder_hidden_states
    
    add missing attn.norm_cross check
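    
    A minimal sketch of what the consolidated helper could look like, assuming `attn` exposes a `norm_cross` attribute holding either an `nn.LayerNorm` or an `nn.GroupNorm`; the signature and attribute names here are illustrative rather than the exact diffusers API:
    
    ```python
    import torch
    import torch.nn as nn
    
    def norm_encoder_hidden_states(attn, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        # Callers are expected to have already checked that attn.norm_cross is set.
        assert attn.norm_cross is not None
    
        if isinstance(attn.norm_cross, nn.LayerNorm):
            # (batch, seq_len, hidden): layer norm already operates on the last dim.
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
        elif isinstance(attn.norm_cross, nn.GroupNorm):
            # Group norm expects (batch, channels, *): flip the last two dims,
            # normalize over the hidden (channels) dim, then flip back.
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            raise ValueError("unknown norm_cross type")
    
        return encoder_hidden_states
    ```
    
    Each processor would then call this helper only after checking `attn.norm_cross is not None`, rather than duplicating the transpose logic inline.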