Commit 4bd4ad93 authored by Gustaf Ahdritz

Add first attempt at scripting attention

parent dd8e44b3
@@ -17,7 +17,7 @@ import math
 import torch
 import torch.nn as nn

-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, scripted_attention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -69,7 +69,7 @@ class MSAAttention(nn.Module):
             self.c_z, self.no_heads, bias=False, init="normal"
         )

-        self.mha = Attention(
+        self.mha = scripted_attention(
             self.c_in, self.c_in, self.c_in,
             self.c_hidden,
             self.no_heads
...
@@ -14,7 +14,7 @@
 # limitations under the License.

 import math
-from typing import Optional, Callable
+from typing import Optional, Callable, List

 import numpy as np
 import torch
@@ -212,7 +212,7 @@ class Attention(nn.Module):
             self.c_hidden * self.no_heads, self.c_q, init="final"
         )

-        if(self.gating):
+        if(self.gating is not None):
             self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating")

         self.sigmoid = nn.Sigmoid()
@@ -222,7 +222,7 @@ class Attention(nn.Module):
         q_x: torch.Tensor,
         k_x: torch.Tensor,
         v_x: torch.Tensor,
-        biases: bool = None,
+        biases: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """
         Args:
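
A note on the signature change in the hunk above: TorchScript rejects a None default on an argument annotated bool, and it needs the concrete container type to type-check any iteration over the biases. A minimal standalone sketch of the pattern, using a hypothetical toy module rather than OpenFold's Attention:

    from typing import List, Optional

    import torch
    import torch.nn as nn


    class AddBiases(nn.Module):
        # Hypothetical toy module; only the annotation pattern is the point.
        def forward(
            self,
            a: torch.Tensor,
            biases: Optional[List[torch.Tensor]] = None,
        ) -> torch.Tensor:
            # TorchScript requires narrowing the Optional before use.
            if biases is not None:
                for b in biases:
                    a = a + b
            return a


    scripted = torch.jit.script(AddBiases())
    x = torch.randn(4, 4)
    print(scripted(x).shape)                      # default None is handled
    print(scripted(x, [torch.ones(4, 4)]).shape)
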
@@ -235,20 +235,26 @@ class Attention(nn.Module):
         Returns
             [*, Q, C_q] attention update
         """
+        # Flatten batch dims
+        batch_dims = q_x.shape[:-2]
+        q_x = q_x.view((-1,) + q_x.shape[-2:])
+        k_x = k_x.view((-1,) + k_x.shape[-2:])
+        v_x = v_x.view((-1,) + v_x.shape[-2:])
+
         # [*, Q/K/V, H * C_hidden]
         q = self.linear_q(q_x)
         k = self.linear_k(k_x)
         v = self.linear_v(v_x)

         # [*, Q/K, H, C_hidden]
-        q = q.view(*q.shape[:-1], self.no_heads, -1)
-        k = k.view(*k.shape[:-1], self.no_heads, -1)
-        v = v.view(*v.shape[:-1], self.no_heads, -1)
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+        k = k.view(k.shape[:-1] + (self.no_heads, -1))
+        v = v.view(v.shape[:-1] + (self.no_heads, -1))

         # [*, H, Q, K]
         a = torch.matmul(
-            permute_final_dims(q, 1, 0, 2),  # [*, H, Q, C_hidden]
-            permute_final_dims(k, 1, 2, 0),  # [*, H, C_hidden, K]
+            q.permute(0, 2, 1, 3),  # [*, H, Q, C_hidden]
+            k.permute(0, 2, 3, 1),  # [*, H, C_hidden, K]
         )
         norm = 1 / math.sqrt(self.c_hidden)  # [1]
         a = a * norm
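
The rewrites in this hunk share one motive: TorchScript cannot star-unpack a runtime shape into a call like view(*shape, ...), whereas slicing the shape and concatenating a tuple does the same job; and once the batch dims are flattened to a fixed rank, literal indices in permute can stand in for the variadic permute_final_dims helper. A standalone sketch of both idioms with made-up shapes (eager mode shown; that the concatenation also compiles under a given TorchScript version is an assumption, not verified here):

    import torch

    x = torch.randn(6, 16, 32)    # [flattened batch, seq, heads * c_hidden]
    no_heads, c_hidden = 4, 8

    # Scriptable idiom: slice the shape and concatenate the new tail dims.
    y = x.view(x.shape[:-1] + (no_heads, c_hidden))   # [6, 16, 4, 8]

    # Not scriptable: TorchScript cannot unpack a runtime List[int]:
    # y = x.view(*x.shape[:-1], no_heads, c_hidden)

    # With the rank fixed at 4, explicit indices replace permute_final_dims.
    y = y.permute(0, 2, 1, 3)                         # [6, 4, 16, 8]
    print(y.shape)
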
@@ -260,7 +266,7 @@ class Attention(nn.Module):
         # [*, H, Q, C_hidden]
         o = torch.matmul(
             a,
-            permute_final_dims(v, 1, 0, 2),  # [*, H, V, C_hidden]
+            v.permute(0, 2, 1, 3),  # [*, H, V, C_hidden]
         )

         # [*, Q, H, C_hidden]
@@ -268,7 +274,7 @@ class Attention(nn.Module):
         if(self.gating):
             g = self.sigmoid(self.linear_g(q_x))
             # [*, Q, H, C_hidden]
-            g = g.view(*g.shape[:-1], self.no_heads, -1)
+            g = g.view(g.shape[:-1] + (self.no_heads, -1))
             o = o * g

         # [*, Q, H * C_hidden]
@@ -276,5 +282,11 @@ class Attention(nn.Module):
         # [*, Q, C_q]
         o = self.linear_o(o)

+        # Restore the batch dims
+        o = o.reshape(batch_dims + o.shape[1:])
+
         return o
+
+def scripted_attention(*args, **kwargs):
+    return torch.jit.script(Attention(*args, **kwargs))
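
For reference, a usage sketch of the new factory, assuming an OpenFold checkout containing this commit and that scripting succeeds as intended; the argument order follows the call sites in this diff, the hyperparameters are made up, and biases is left at its None default:

    import torch

    # Assumes openfold is importable from a checkout at this commit.
    from openfold.model.primitives import scripted_attention

    # (c_q, c_k, c_v, c_hidden, no_heads), mirroring the call sites above.
    mha = scripted_attention(32, 32, 32, 16, 4)
    print(isinstance(mha, torch.jit.ScriptModule))   # True: compiled, not traced

    # Arbitrary leading batch dims are flattened and restored by forward().
    x = torch.randn(2, 10, 64, 32)                   # [*, Q, C_q]
    out = mha(x, x, x)                               # self-attention, no biases
    print(out.shape)                                 # torch.Size([2, 10, 64, 32])
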
@@ -18,7 +18,7 @@ import math
 import torch
 import torch.nn as nn

-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, scripted_attention
 from openfold.utils.deepspeed import checkpoint_blocks
 from openfold.model.dropout import (
     DropoutRowwise,
@@ -69,7 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
         self.no_heads = no_heads
         self.chunk_size = chunk_size

-        self.mha = Attention(
+        self.mha = scripted_attention(
             self.c_z, self.c_t, self.c_t,
             self.c_hidden, self.no_heads,
             gating=False,
...
@@ -18,7 +18,7 @@ import math
 import torch
 import torch.nn as nn

-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, scripted_attention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -57,7 +57,7 @@ class TriangleAttention(nn.Module):
         self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

-        self.mha = Attention(
+        self.mha = scripted_attention(
             self.c_in, self.c_in, self.c_in,
             self.c_hidden,
             self.no_heads
...
@@ -24,8 +24,8 @@ def permute_final_dims(tensor, *inds):
     return tensor.permute(*first_inds, *[zero_index + i for i in inds])


-def flatten_final_dims(tensor, no_dims):
-    return tensor.reshape(*tensor.shape[:-no_dims], -1)
+def flatten_final_dims(tensor: torch.Tensor, no_dims: int):
+    return tensor.reshape(tensor.shape[:-no_dims] + (-1,))


 def masked_mean(mask, value, dim, eps=1e-10):
...
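
The annotations added to flatten_final_dims matter because torch.jit.script infers every un-annotated argument as Tensor, which would make no_dims ill-typed. A quick eager-mode check of the rewritten helper, re-declared here for illustration (that the tuple concatenation also compiles under a given TorchScript version is an assumption):

    import torch

    def flatten_final_dims(tensor: torch.Tensor, no_dims: int) -> torch.Tensor:
        # torch.Size + tuple yields a new torch.Size, so no *-unpacking
        # is needed to build the target shape.
        return tensor.reshape(tensor.shape[:-no_dims] + (-1,))

    x = torch.randn(2, 3, 4, 5)
    print(flatten_final_dims(x, 2).shape)   # torch.Size([2, 3, 20])
    print(flatten_final_dims(x, 3).shape)   # torch.Size([2, 60])
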