Unverified Commit 05deb52d authored by Michael Pang, committed by GitHub

Optimize causal mask using torch.where (#2715)

* Optimize causal mask using torch.where

Instead of multiplying by a 0/1 float mask, use torch.where with a bool mask for increased performance (a standalone sketch of the difference follows the diff below).

* Maintain compatibility with torch 1.0.0 - thanks for PR feedback

* Fix typo

* Reformat line for CI
parent 0a4b1068
@@ -104,7 +104,10 @@ class Attention(nn.Module):
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
         assert n_state % config.n_head == 0
-        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
+        self.register_buffer(
+            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
+        )
+        self.register_buffer("masked_bias", torch.tensor(-1e4))
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
@@ -142,8 +145,8 @@ class Attention(nn.Module):
         if self.scale:
             w = w / math.sqrt(v.size(-1))
         nd, ns = w.size(-2), w.size(-1)
-        b = self.bias[:, :, ns - nd : ns, :ns]
-        w = w * b - 1e4 * (1 - b)
+        mask = self.bias[:, :, ns - nd : ns, :ns]
+        w = torch.where(mask, w, self.masked_bias)
         if attention_mask is not None:
             # Apply the attention mask
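For illustration only (not part of the commit): a minimal, runnable sketch contrasting the old float-mask arithmetic with the torch.where version, checking that both produce identical scores. The shapes and the n_ctx value here are made up. One caveat: recent PyTorch releases require a bool condition in torch.where, so the sketch calls .bool() on the uint8 buffer, whereas the torch-1.0.0-era code in the diff could pass the uint8 mask directly.

import torch

n_ctx = 8  # illustrative context length
# uint8 rather than bool keeps the buffer usable on torch 1.0.0,
# which predates the torch.bool dtype.
bias = torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
masked_bias = torch.tensor(-1e4)

w = torch.randn(2, 4, n_ctx, n_ctx)  # fake attention scores: (batch, heads, nd, ns)
nd, ns = w.size(-2), w.size(-1)
mask = bias[:, :, ns - nd : ns, :ns]

# Old approach: three elementwise float ops, each materializing a temporary.
b = mask.float()
w_old = w * b - 1e4 * (1 - b)

# New approach: a single select against the cached -1e4 buffer.
w_new = torch.where(mask.bool(), w, masked_bias)

assert torch.allclose(w_old, w_new)

The speedup the commit message claims would come from replacing the multiply/subtract chain with a single select kernel; registering -1e4 as a buffer also lets it follow the module across devices and dtypes instead of being re-created on every forward pass.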