Unverified Commit 14976500 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

fuse attention mask (#2111)

* fuse attention mask

* lint

* use 0 beta when no attention mask re: @Birch-san
parent 96af5bf7
...@@ -185,17 +185,23 @@ class CrossAttention(nn.Module): ...@@ -185,17 +185,23 @@ class CrossAttention(nn.Module):
query = query.float() query = query.float()
key = key.float() key = key.float()
if attention_mask is None:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
beta = 0
else:
baddbmm_input = attention_mask
beta = 1
attention_scores = torch.baddbmm( attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), baddbmm_input,
query, query,
key.transpose(-1, -2), key.transpose(-1, -2),
beta=0, beta=beta,
alpha=self.scale, alpha=self.scale,
) )
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
if self.upcast_softmax: if self.upcast_softmax:
attention_scores = attention_scores.float() attention_scores = attention_scores.float()
...@@ -228,11 +234,12 @@ class CrossAttnProcessor: ...@@ -228,11 +234,12 @@ class CrossAttnProcessor:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states) query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states) key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key) key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value) value = attn.head_to_batch_dim(value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment