Commit 4bd4ad93 authored by Gustaf Ahdritz

Add first attempt at scripting attention

parent dd8e44b3
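The diff below swaps direct Attention(...) construction for a scripted_attention factory that compiles the module with TorchScript, and reworks the tensor reshaping so it survives compilation. A minimal sketch of the underlying pattern, with a toy module standing in for OpenFold's Attention (all names and dimensions here are illustrative, not OpenFold's):

    import math
    import torch
    import torch.nn as nn

    class TinyAttention(nn.Module):
        # Toy stand-in for openfold's Attention; self-attention only.
        def __init__(self, c_in: int, c_hidden: int):
            super().__init__()
            self.linear_q = nn.Linear(c_in, c_hidden)
            self.linear_k = nn.Linear(c_in, c_hidden)
            self.linear_v = nn.Linear(c_in, c_hidden)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            q = self.linear_q(x)
            k = self.linear_k(x)
            v = self.linear_v(x)
            a = torch.softmax(
                torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.shape[-1]),
                dim=-1,
            )
            return torch.matmul(a, v)

    # torch.jit.script compiles forward() into TorchScript, which is why
    # the diff tightens type annotations and drops Python-only constructs
    # such as *-unpacking of tensor shapes.
    mha = torch.jit.script(TinyAttention(32, 64))
    out = mha(torch.randn(2, 8, 32))  # [batch, seq, c_in] -> (2, 8, 64)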
@@ -17,7 +17,7 @@ import math
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, scripted_attention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -69,7 +69,7 @@ class MSAAttention(nn.Module):
             self.c_z, self.no_heads, bias=False, init="normal"
         )
 
-        self.mha = Attention(
+        self.mha = scripted_attention(
             self.c_in, self.c_in, self.c_in,
             self.c_hidden,
             self.no_heads
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import math
-from typing import Optional, Callable
+from typing import Optional, Callable, List
 
 import numpy as np
 import torch
@@ -212,7 +212,7 @@ class Attention(nn.Module):
             self.c_hidden * self.no_heads, self.c_q, init="final"
         )
 
-        if(self.gating):
+        if(self.gating is not None):
            self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating")
            self.sigmoid = nn.Sigmoid()
@@ -222,7 +222,7 @@ class Attention(nn.Module):
         q_x: torch.Tensor,
         k_x: torch.Tensor,
         v_x: torch.Tensor,
-        biases: bool = None,
+        biases: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """
         Args:
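TorchScript type-checks the signatures it compiles, so the loose biases: bool = None (a None default on a bool) has to become an honest Optional[List[torch.Tensor]], and the value must be narrowed before use. A small sketch of that requirement, using a hypothetical module rather than the real class:

    from typing import List, Optional

    import torch
    import torch.nn as nn

    class BiasedSum(nn.Module):
        # Hypothetical module; only the signature pattern matters here.
        def forward(
            self,
            x: torch.Tensor,
            biases: Optional[List[torch.Tensor]] = None,
        ) -> torch.Tensor:
            # TorchScript requires narrowing the Optional before iterating
            if biases is not None:
                for b in biases:
                    x = x + b
            return x

    scripted = torch.jit.script(BiasedSum())
    x = torch.randn(4, 4)
    assert torch.equal(scripted(x), x)
    assert torch.equal(scripted(x, [torch.ones(4, 4)]), x + 1)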
@@ -235,20 +235,26 @@ class Attention(nn.Module):
         Returns
             [*, Q, C_q] attention update
         """
+        # Flatten batch dims
+        batch_dims = q_x.shape[:-2]
+        q_x = q_x.view((-1,) + q_x.shape[-2:])
+        k_x = k_x.view((-1,) + k_x.shape[-2:])
+        v_x = v_x.view((-1,) + v_x.shape[-2:])
+
         # [*, Q/K/V, H * C_hidden]
         q = self.linear_q(q_x)
         k = self.linear_k(k_x)
         v = self.linear_v(v_x)
 
         # [*, Q/K, H, C_hidden]
-        q = q.view(*q.shape[:-1], self.no_heads, -1)
-        k = k.view(*k.shape[:-1], self.no_heads, -1)
-        v = v.view(*v.shape[:-1], self.no_heads, -1)
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+        k = k.view(k.shape[:-1] + (self.no_heads, -1))
+        v = v.view(v.shape[:-1] + (self.no_heads, -1))
 
         # [*, H, Q, K]
         a = torch.matmul(
-            permute_final_dims(q, 1, 0, 2),  # [*, H, Q, C_hidden]
-            permute_final_dims(k, 1, 2, 0),  # [*, H, C_hidden, K]
+            q.permute(0, 2, 1, 3),  # [*, H, Q, C_hidden]
+            k.permute(0, 2, 3, 1),  # [*, H, C_hidden, K]
         )
         norm = 1 / math.sqrt(self.c_hidden)  # [1]
         a = a * norm
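The reshaping changes in this hunk work around two TorchScript limitations: view(*shape[:-1], h, -1) star-unpacks a torch.Size, which the compiler rejects, while tuple concatenation on torch.Size is fine; and the variadic permute_final_dims helper is replaced with fixed four-dimensional permute calls, which is what makes flattening the batch dims beforehand necessary. A quick eager-mode sketch of the equivalence (shapes are illustrative):

    import torch

    no_heads = 8
    x = torch.randn(2, 5, 7, 64)  # [*, Q, H * C_hidden]

    # Rejected by the TorchScript compiler (Python-only star-unpacking):
    #   x.view(*x.shape[:-1], no_heads, -1)
    # Accepted: torch.Size is a tuple, so concatenation builds the same shape.
    heads = x.view(x.shape[:-1] + (no_heads, -1))  # [*, Q, H, C_hidden]

    # Once the batch dims are flattened into one leading axis, the variadic
    # permute_final_dims(q, 1, 0, 2) becomes a fixed permute(0, 2, 1, 3).
    flat = heads.view((-1,) + heads.shape[-3:])  # [B, Q, H, C_hidden]
    moved = flat.permute(0, 2, 1, 3)             # [B, H, Q, C_hidden]
    assert moved.shape == (2 * 5, 8, 7, 8)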
@@ -260,7 +266,7 @@ class Attention(nn.Module):
         # [*, H, Q, C_hidden]
         o = torch.matmul(
             a,
-            permute_final_dims(v, 1, 0, 2),  # [*, H, V, C_hidden]
+            v.permute(0, 2, 1, 3),  # [*, H, V, C_hidden]
         )
 
         # [*, Q, H, C_hidden]
@@ -268,7 +274,7 @@ class Attention(nn.Module):
         if(self.gating):
             g = self.sigmoid(self.linear_g(q_x))
 
             # [*, Q, H, C_hidden]
-            g = g.view(*g.shape[:-1], self.no_heads, -1)
+            g = g.view(g.shape[:-1] + (self.no_heads, -1))
             o = o * g
 
         # [*, Q, H * C_hidden]
@@ -276,5 +282,11 @@ class Attention(nn.Module):
         # [*, Q, C_q]
         o = self.linear_o(o)
 
+        # Restore the batch dims
+        o = o.reshape(batch_dims + o.shape[1:])
+
         return o
+
+
+def scripted_attention(*args, **kwargs):
+    return torch.jit.script(Attention(*args, **kwargs))
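At the call sites above, the factory is a drop-in replacement: constructor arguments are unchanged, but the returned module is TorchScript-compiled. A usage sketch, assuming the module compiles as intended (the commit title flags this as a first attempt; dimension values are illustrative):

    import torch

    from openfold.model.primitives import scripted_attention

    # Same (c_q, c_k, c_v, c_hidden, no_heads) arguments as Attention
    mha = scripted_attention(64, 64, 64, 16, 4)
    assert isinstance(mha, torch.jit.ScriptModule)

    x = torch.randn(1, 8, 64)  # [*, Q, C_q]
    out = mha(x, x, x)         # biases defaults to None
    assert out.shape == x.shape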
@@ -18,7 +18,7 @@ import math
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, scripted_attention
 from openfold.utils.deepspeed import checkpoint_blocks
 from openfold.model.dropout import (
     DropoutRowwise,
@@ -69,7 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
         self.no_heads = no_heads
         self.chunk_size = chunk_size
 
-        self.mha = Attention(
+        self.mha = scripted_attention(
             self.c_z, self.c_t, self.c_t,
             self.c_hidden, self.no_heads,
             gating=False,
@@ -18,7 +18,7 @@ import math
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, scripted_attention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -57,7 +57,7 @@ class TriangleAttention(nn.Module):
 
         self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
 
-        self.mha = Attention(
+        self.mha = scripted_attention(
             self.c_in, self.c_in, self.c_in,
             self.c_hidden,
             self.no_heads
@@ -24,8 +24,8 @@ def permute_final_dims(tensor, *inds):
     return tensor.permute(*first_inds, *[zero_index + i for i in inds])
 
 
-def flatten_final_dims(tensor, no_dims):
-    return tensor.reshape(*tensor.shape[:-no_dims], -1)
+def flatten_final_dims(tensor: torch.Tensor, no_dims: int):
+    return tensor.reshape(tensor.shape[:-no_dims] + (-1,))
 
 
 def masked_mean(mask, value, dim, eps=1e-10):
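The flatten_final_dims change is the same tuple-concatenation trick in miniature: *tensor.shape[:-no_dims] cannot be star-unpacked under TorchScript, but torch.Size supports +, so the annotated version expresses the identical reshape. A quick eager-mode check:

    import torch

    def flatten_final_dims(tensor: torch.Tensor, no_dims: int) -> torch.Tensor:
        # Collapse the last no_dims axes into one
        return tensor.reshape(tensor.shape[:-no_dims] + (-1,))

    t = torch.arange(24).reshape(2, 3, 4)
    assert flatten_final_dims(t, 2).shape == (2, 12)  # 3 * 4 -> 12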