Unverified commit 14976500 authored by Will Berman, committed by GitHub

fuse attention mask (#2111)

* fuse attention mask

* lint

* use 0 beta when no attention mask re: @Birch-san
parent 96af5bf7
@@ -185,17 +185,23 @@ class CrossAttention(nn.Module):
             query = query.float()
             key = key.float()
 
+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
         attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            baddbmm_input,
             query,
             key.transpose(-1, -2),
-            beta=0,
+            beta=beta,
             alpha=self.scale,
         )
 
-        if attention_mask is not None:
-            attention_scores = attention_scores + attention_mask
-
         if self.upcast_softmax:
             attention_scores = attention_scores.float()
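
The point of the change: torch.baddbmm computes out = beta * input + alpha * (batch1 @ batch2) in one kernel, so passing the prepared mask as `input` with beta=1 fuses the old separate `attention_scores + attention_mask` addition into the score matmul. When there is no mask, beta=0 makes baddbmm skip reading `input` entirely, so the uninitialized torch.empty placeholder (which may contain garbage or NaNs) is never touched; that is the "use 0 beta" fix credited to @Birch-san. A minimal sketch checking the equivalence, with illustrative shapes that are not taken from the diff:

import torch

# Illustrative sizes: batch*heads, query length, key/value length, head dim.
batch_heads, q_len, kv_len, dim = 2, 4, 6, 8
scale = dim ** -0.5

query = torch.randn(batch_heads, q_len, dim)
key = torch.randn(batch_heads, kv_len, dim)
attention_mask = torch.randn(batch_heads, q_len, kv_len)  # additive mask

# Old path: raw scores in one kernel, mask added in a second.
unfused = scale * torch.bmm(query, key.transpose(-1, -2)) + attention_mask

# New path: the mask rides along as baddbmm's `input` with beta=1.
fused = torch.baddbmm(attention_mask, query, key.transpose(-1, -2), beta=1, alpha=scale)
assert torch.allclose(unfused, fused, atol=1e-5)

# No mask: beta=0 means `input` is ignored (NaNs in the uninitialized
# buffer are not propagated), so torch.empty is safe as a placeholder.
placeholder = torch.empty(batch_heads, q_len, kv_len)
no_mask = torch.baddbmm(placeholder, query, key.transpose(-1, -2), beta=0, alpha=scale)
assert torch.allclose(no_mask, scale * torch.bmm(query, key.transpose(-1, -2)), atol=1e-5)
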
@@ -228,11 +234,12 @@ class CrossAttnProcessor:
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
         query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
 
         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
+        query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
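
The second hunk is a pure reordering: the query reshape moves down so that query, key, and value all go through head_to_batch_dim together after the projections. For context, a sketch of what head_to_batch_dim does, assumed from diffusers' conventions rather than shown in this diff: it folds the head dimension into the batch dimension, which is why the mask prepared above has to line up with the batch_size * heads leading dimension that baddbmm batches over.

import torch

def head_to_batch_dim(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # Assumed sketch: [B, S, H*D] -> [B, S, H, D] -> [B, H, S, D] -> [B*H, S, D]
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)

q = torch.randn(2, 77, 320)                  # [batch, seq_len, inner_dim]
print(head_to_batch_dim(q, heads=8).shape)   # torch.Size([16, 77, 40])
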