Commit 34e9363c authored by Gustaf Ahdritz

Make more components TorchScript-able, add tracing

parent 34e4e6ce
...@@ -71,11 +71,25 @@ class MSATransition(nn.Module): ...@@ -71,11 +71,25 @@ class MSATransition(nn.Module):
m = self.linear_2(m) * mask m = self.linear_2(m) * mask
return m return m
@torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self._transition,
{"m": m, "mask": mask},
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
def forward( def forward(
self, self,
m: torch.Tensor, m: torch.Tensor,
mask: torch.Tensor = None, mask: Optional[torch.Tensor] = None,
chunk_size: int = None, chunk_size: Optional[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -95,16 +109,10 @@ class MSATransition(nn.Module): ...@@ -95,16 +109,10 @@ class MSATransition(nn.Module):
m = self.layer_norm(m) m = self.layer_norm(m)
inp = {"m": m, "mask": mask}
if chunk_size is not None: if chunk_size is not None:
m = chunk_layer( m = self._chunk(m, mask, chunk_size)
self._transition,
inp,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
else: else:
m = self._transition(**inp) m = self._transition(m, mask)
return m return m
...@@ -201,9 +209,11 @@ class EvoformerBlock(nn.Module): ...@@ -201,9 +209,11 @@ class EvoformerBlock(nn.Module):
z: torch.Tensor, z: torch.Tensor,
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: int, chunk_size: Optional[int] = None,
_mask_trans: bool = True, _mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
#print(torch.cuda.memory_summary())
# DeepMind doesn't mask these transitions in the source, so _mask_trans # DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of # should be disabled to better approximate the exact activations of
# the original. # the original.
......
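
Note: the MSATransition change above is the template for most of this commit. The chunk_layer call is hoisted into a _chunk helper marked @torch.jit.ignore (presumably because chunk_layer is not TorchScript-friendly), optional arguments gain Optional[...] annotations, and **kwargs expansion is replaced by explicit argument passing so that forward() itself compiles. A rough, self-contained sketch of that pattern — ToyTransition is an invented stand-in, and chunk_layer is assumed to behave as it is called in this diff:

    from typing import Optional

    import torch
    import torch.nn as nn

    from openfold.utils.tensor_utils import chunk_layer


    class ToyTransition(nn.Module):
        def __init__(self, c: int):
            super().__init__()
            self.linear = nn.Linear(c, c)

        def _transition(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
            return self.linear(m) * mask

        @torch.jit.ignore
        def _chunk(self, m: torch.Tensor, mask: torch.Tensor, chunk_size: int) -> torch.Tensor:
            # kept out of the TorchScript graph; only ever runs in eager mode
            return chunk_layer(
                self._transition,
                {"m": m, "mask": mask},
                chunk_size=chunk_size,
                no_batch_dims=len(m.shape[:-2]),
            )

        def forward(self,
            m: torch.Tensor,
            mask: Optional[torch.Tensor] = None,
            chunk_size: Optional[int] = None,
        ) -> torch.Tensor:
            if mask is None:
                mask = m.new_ones(m.shape[:-1])
            mask = mask.unsqueeze(-1)
            if chunk_size is not None:
                m = self._chunk(m, mask, chunk_size)
            else:
                m = self._transition(m, mask)  # explicit args, no **kwargs
            return m


    scripted = torch.jit.script(ToyTransition(8))  # forward compiles; _chunk stays Python
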
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import math import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Optional from typing import Optional, List
from openfold.model.primitives import Linear, Attention, GlobalAttention from openfold.model.primitives import Linear, Attention, GlobalAttention
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
...@@ -63,6 +63,8 @@ class MSAAttention(nn.Module): ...@@ -63,6 +63,8 @@ class MSAAttention(nn.Module):
self.layer_norm_m = nn.LayerNorm(self.c_in) self.layer_norm_m = nn.LayerNorm(self.c_in)
self.layer_norm_z = None
self.linear_z = None
if self.pair_bias: if self.pair_bias:
self.layer_norm_z = nn.LayerNorm(self.c_z) self.layer_norm_z = nn.LayerNorm(self.c_z)
self.linear_z = Linear( self.linear_z = Linear(
...@@ -73,7 +75,25 @@ class MSAAttention(nn.Module): ...@@ -73,7 +75,25 @@ class MSAAttention(nn.Module):
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
) )
def forward(self, m, chunk_size, z=None, mask=None): @torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self.mha,
{"q_x": m, "k_x": m, "v_x": m, "biases": biases},
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
def forward(self,
m: torch.Tensor,
z: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
""" """
Args: Args:
m: m:
...@@ -83,6 +103,11 @@ class MSAAttention(nn.Module): ...@@ -83,6 +103,11 @@ class MSAAttention(nn.Module):
pair_bias is True pair_bias is True
mask: mask:
[*, N_seq, N_res] MSA mask [*, N_seq, N_res] MSA mask
chunk_size:
Size of chunks into which the inputs are split along their
batch dimensions. A low value decreases memory overhead at the
cost of slower execution. Chunking is not performed by default.
""" """
# [*, N_seq, N_res, C_m] # [*, N_seq, N_res, C_m]
m = self.layer_norm_m(m) m = self.layer_norm_m(m)
...@@ -106,7 +131,11 @@ class MSAAttention(nn.Module): ...@@ -106,7 +131,11 @@ class MSAAttention(nn.Module):
biases = [bias] biases = [bias]
if self.pair_bias: if (self.pair_bias and
z is not None and # For the
self.layer_norm_z is not None and # benefit of
self.linear_z is not None # TorchScript
):
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
z = self.layer_norm_z(z) z = self.layer_norm_z(z)
...@@ -118,16 +147,10 @@ class MSAAttention(nn.Module): ...@@ -118,16 +147,10 @@ class MSAAttention(nn.Module):
biases.append(z) biases.append(z)
mha_inputs = {"q_x": m, "k_x": m, "v_x": m, "biases": biases}
if chunk_size is not None: if chunk_size is not None:
m = chunk_layer( m = self._chunk(m, biases, chunk_size)
self.mha,
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
else: else:
m = self.mha(**mha_inputs) m = self.mha(q_x=m, k_x=m, v_x=m, biases=biases)
return m return m
...@@ -161,9 +184,12 @@ class MSARowAttentionWithPairBias(MSAAttention): ...@@ -161,9 +184,12 @@ class MSARowAttentionWithPairBias(MSAAttention):
) )
class MSAColumnAttention(MSAAttention): class MSAColumnAttention(nn.Module):
""" """
Implements Algorithm 8. Implements Algorithm 8.
By rights, this should also be a subclass of MSAAttention. Alas,
most inheritance isn't supported by TorchScript.
""" """
def __init__(self, c_m, c_hidden, no_heads, inf=1e9): def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
...@@ -178,7 +204,14 @@ class MSAColumnAttention(MSAAttention): ...@@ -178,7 +204,14 @@ class MSAColumnAttention(MSAAttention):
inf: inf:
Large number used to construct attention masks Large number used to construct attention masks
""" """
super(MSAColumnAttention, self).__init__( super(MSAColumnAttention, self).__init__()
self.c_m = c_m
self.c_hidden = c_hidden
self.no_heads = no_heads
self.inf = inf
self._msa_att = MSAAttention(
c_in=c_m, c_in=c_m,
c_hidden=c_hidden, c_hidden=c_hidden,
no_heads=no_heads, no_heads=no_heads,
...@@ -187,31 +220,40 @@ class MSAColumnAttention(MSAAttention): ...@@ -187,31 +220,40 @@ class MSAColumnAttention(MSAAttention):
inf=inf, inf=inf,
) )
def forward(self, m, chunk_size, mask=None): def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
) -> torch.Tensor:
""" """
Args: Args:
m: m:
[*, N_seq, N_res, C_m] MSA embedding [*, N_seq, N_res, C_m] MSA embedding
mask: mask:
[*, N_seq, N_res] MSA mask [*, N_seq, N_res] MSA mask
chunk_size:
Size of chunks into which the inputs are split along their
batch dimensions. A low value decreases memory overhead at the
cost of slower execution. Chunking is not performed by default.
""" """
# [*, N_res, N_seq, C_in] # [*, N_res, N_seq, C_in]
m = m.transpose(-2, -3) m = m.transpose(-2, -3)
if mask is not None: if mask is not None:
mask = mask.transpose(-1, -2) mask = mask.transpose(-1, -2)
m = super().forward(m, chunk_size=chunk_size, mask=mask) m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
# [*, N_seq, N_res, C_in] # [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3) m = m.transpose(-2, -3)
if mask is not None: if mask is not None:
mask = mask.transpose(-1, -2) mask = mask.transpose(-1, -2)
return m return m
class MSAColumnGlobalAttention(nn.Module): class MSAColumnGlobalAttention(nn.Module):
def __init__( def __init__(
self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10 self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
): ):
super(MSAColumnGlobalAttention, self).__init__() super(MSAColumnGlobalAttention, self).__init__()
...@@ -231,8 +273,28 @@ class MSAColumnGlobalAttention(nn.Module): ...@@ -231,8 +273,28 @@ class MSAColumnGlobalAttention(nn.Module):
eps=eps, eps=eps,
) )
@torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
mha_input = {
"m": m,
"mask": mask,
}
return chunk_layer(
self.global_attention,
mha_input,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
def forward( def forward(
self, m: torch.Tensor, chunk_size, mask: Optional[torch.Tensor] = None self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:] n_seq, n_res, c_in = m.shape[-3:]
...@@ -251,19 +313,10 @@ class MSAColumnGlobalAttention(nn.Module): ...@@ -251,19 +313,10 @@ class MSAColumnGlobalAttention(nn.Module):
# [*, N_res, N_seq, C_in] # [*, N_res, N_seq, C_in]
m = self.layer_norm_m(m) m = self.layer_norm_m(m)
mha_input = {
"m": m,
"mask": mask,
}
if chunk_size is not None: if chunk_size is not None:
m = chunk_layer( m = self._chunk(m, mask, chunk_size)
self.global_attention,
mha_input,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
else: else:
m = self.global_attention(m=mha_input["m"], mask=mha_input["mask"]) m = self.global_attention(m=m, mask=mask)
# [*, N_seq, N_res, C_in] # [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3) m = m.transpose(-2, -3)
......
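
Two TorchScript accommodations appear in msa.py: optional submodules are guarded with explicit is-not-None checks so the compiler can refine their types, and MSAColumnAttention now wraps an MSAAttention instance instead of subclassing it. A minimal sketch of the composition workaround, using invented Toy* modules:

    import torch
    import torch.nn as nn


    class ToyRowAttention(nn.Module):
        def __init__(self, c: int):
            super().__init__()
            self.proj = nn.Linear(c, c)

        def forward(self, m: torch.Tensor) -> torch.Tensor:
            return self.proj(m)


    class ToyColumnAttention(nn.Module):
        def __init__(self, c: int):
            super().__init__()
            # composition: own the row module rather than inherit from it
            self._att = ToyRowAttention(c)

        def forward(self, m: torch.Tensor) -> torch.Tensor:
            m = m.transpose(-2, -3)  # operate along columns by transposing
            m = self._att(m)
            return m.transpose(-2, -3)


    col = torch.jit.script(ToyColumnAttention(16))  # scripts without inheritance issues
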
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from functools import partial from functools import partial
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -38,6 +40,7 @@ class OuterProductMean(nn.Module): ...@@ -38,6 +40,7 @@ class OuterProductMean(nn.Module):
""" """
super(OuterProductMean, self).__init__() super(OuterProductMean, self).__init__()
self.c_m = c_m
self.c_z = c_z self.c_z = c_z
self.c_hidden = c_hidden self.c_hidden = c_hidden
self.eps = eps self.eps = eps
...@@ -52,14 +55,43 @@ class OuterProductMean(nn.Module): ...@@ -52,14 +55,43 @@ class OuterProductMean(nn.Module):
outer = torch.einsum("...bac,...dae->...bdce", a, b) outer = torch.einsum("...bac,...dae->...bdce", a, b)
# [*, N_res, N_res, C * C] # [*, N_res, N_res, C * C]
outer = outer.reshape(*outer.shape[:-2], -1) outer = outer.reshape(outer.shape[:-2] + (-1,))
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
outer = self.linear_out(outer) outer = self.linear_out(outer)
return outer return outer
def forward(self, m, chunk_size, mask=None): @torch.jit.ignore
def _chunk(self,
a: torch.Tensor,
b: torch.Tensor,
chunk_size: int
) -> torch.Tensor:
# Since the "batch dim" in this case is not a true batch dimension
# (in that the shape of the output depends on it), we need to
# iterate over it ourselves
a_reshape = a.reshape((-1,) + a.shape[-3:])
b_reshape = b.reshape((-1,) + b.shape[-3:])
out = []
for a_prime, b_prime in zip(a_reshape, b_reshape):
outer = chunk_layer(
partial(self._opm, b=b_prime),
{"a": a_prime},
chunk_size=chunk_size,
no_batch_dims=1,
)
out.append(outer)
outer = torch.stack(out, dim=0)
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
return outer
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
) -> torch.Tensor:
""" """
Args: Args:
m: m:
...@@ -84,22 +116,7 @@ class OuterProductMean(nn.Module): ...@@ -84,22 +116,7 @@ class OuterProductMean(nn.Module):
b = b.transpose(-2, -3) b = b.transpose(-2, -3)
if chunk_size is not None: if chunk_size is not None:
# Since the "batch dim" in this case is not a true batch dimension outer = self._chunk(a, b, chunk_size)
# (in that the shape of the output depends on it), we need to
# iterate over it ourselves
a_reshape = a.reshape(-1, *a.shape[-3:])
b_reshape = b.reshape(-1, *b.shape[-3:])
out = []
for a_prime, b_prime in zip(a_reshape, b_reshape):
outer = chunk_layer(
partial(self._opm, b=b_prime),
{"a": a_prime},
chunk_size=chunk_size,
no_batch_dims=1,
)
out.append(outer)
outer = torch.stack(out, dim=0)
outer = outer.reshape(*a.shape[:-3], *outer.shape[1:])
else: else:
outer = self._opm(a, b) outer = self._opm(a, b)
......
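
The comment in the new _chunk above is the crux of the OuterProductMean change: the outer product contracts over N_seq and emits an [N_res, N_res, ...] map, so the leading N_res axis of "a" can be chunked row-wise while "b" must be passed whole (hence functools.partial), and the genuine batch dimensions are looped over by hand. A small shape check under arbitrary, assumed sizes:

    import torch

    n_res, n_seq, c = 6, 5, 4
    a = torch.rand(n_res, n_seq, c)
    b = torch.rand(n_res, n_seq, c)

    def opm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # [N_res, N_res, C, C]; contracts over the shared N_seq axis
        return torch.einsum("...bac,...dae->...bdce", a, b)

    full = opm(a, b)
    chunks = [opm(a_chunk, b) for a_chunk in a.split(2, dim=0)]
    assert torch.allclose(full, torch.cat(chunks, dim=0))  # rows line up with chunks of "a"
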
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -54,7 +55,25 @@ class PairTransition(nn.Module): ...@@ -54,7 +55,25 @@ class PairTransition(nn.Module):
return z return z
def forward(self, z, chunk_size, mask=None): @torch.jit.ignore
def _chunk(self,
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self._transition,
{"z": z, "mask": mask},
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
""" """
Args: Args:
z: z:
...@@ -72,15 +91,9 @@ class PairTransition(nn.Module): ...@@ -72,15 +91,9 @@ class PairTransition(nn.Module):
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
z = self.layer_norm(z) z = self.layer_norm(z)
inp = {"z": z, "mask": mask}
if chunk_size is not None: if chunk_size is not None:
z = chunk_layer( z = self._chunk(z, mask, chunk_size)
self._transition,
inp,
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
)
else: else:
z = self._transition(**inp) z = self._transition(z=z, mask=mask)
return z return z
...@@ -155,17 +155,16 @@ class InvariantPointAttention(nn.Module): ...@@ -155,17 +155,16 @@ class InvariantPointAttention(nn.Module):
""" """
Implements Algorithm 22. Implements Algorithm 22.
""" """
def __init__( def __init__(
self, self,
c_s, c_s: int,
c_z, c_z: int,
c_hidden, c_hidden: int,
no_heads, no_heads: int,
no_qk_points, no_qk_points: int,
no_v_points, no_v_points: int,
inf=1e5, inf: float = 1e5,
eps=1e-8, eps: float = 1e-8,
): ):
""" """
Args: Args:
......
...@@ -12,9 +12,10 @@ ...@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial from functools import partial
import math import math
from typing import Optional, List
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -71,7 +72,32 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -71,7 +72,32 @@ class TemplatePointwiseAttention(nn.Module):
gating=False, gating=False,
) )
def forward(self, t, z, chunk_size, template_mask=None): def _chunk(self,
z: torch.Tensor,
t: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
) -> torch.Tensor:
mha_inputs = {
"q_x": z,
"k_x": t,
"v_x": t,
"biases": biases,
}
return chunk_layer(
self.mha,
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
t: torch.Tensor,
z: torch.Tensor,
template_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
) -> torch.Tensor:
""" """
Args: Args:
t: t:
...@@ -95,21 +121,11 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -95,21 +121,11 @@ class TemplatePointwiseAttention(nn.Module):
t = permute_final_dims(t, (1, 2, 0, 3)) t = permute_final_dims(t, (1, 2, 0, 3))
# [*, N_res, N_res, 1, C_z] # [*, N_res, N_res, 1, C_z]
mha_inputs = { biases = [bias]
"q_x": z,
"k_x": t,
"v_x": t,
"biases": [bias],
}
if chunk_size is not None: if chunk_size is not None:
z = chunk_layer( z = self._chunk(z, t, biases, chunk_size)
self.mha,
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
)
else: else:
z = self.mha(**mha_inputs) z = self.mha(q_x=z, k_x=t, v_x=t, biases=biases)
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
z = z.squeeze(-2) z = z.squeeze(-2)
...@@ -120,13 +136,13 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -120,13 +136,13 @@ class TemplatePointwiseAttention(nn.Module):
class TemplatePairStackBlock(nn.Module): class TemplatePairStackBlock(nn.Module):
def __init__( def __init__(
self, self,
c_t, c_t: int,
c_hidden_tri_att, c_hidden_tri_att: int,
c_hidden_tri_mul, c_hidden_tri_mul: int,
no_heads, no_heads: int,
pair_transition_n, pair_transition_n: int,
dropout_rate, dropout_rate: float,
inf, inf: float,
**kwargs, **kwargs,
): ):
super(TemplatePairStackBlock, self).__init__() super(TemplatePairStackBlock, self).__init__()
...@@ -169,7 +185,12 @@ class TemplatePairStackBlock(nn.Module): ...@@ -169,7 +185,12 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n, self.pair_transition_n,
) )
def forward(self, z, mask, chunk_size, _mask_trans=True): def forward(self,
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True
):
single_templates = [ single_templates = [
t.unsqueeze(-4) for t in torch.unbind(z, dim=-4) t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
] ]
...@@ -208,8 +229,8 @@ class TemplatePairStackBlock(nn.Module): ...@@ -208,8 +229,8 @@ class TemplatePairStackBlock(nn.Module):
) )
single = single + self.pair_transition( single = single + self.pair_transition(
single, single,
mask=single_mask if _mask_trans else None,
chunk_size=chunk_size, chunk_size=chunk_size,
mask=single_mask if _mask_trans else None
) )
single_templates[i] = single single_templates[i] = single
......
from typing import Optional, Sequence # Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.dropout import (
DropoutRowwise,
DropoutColumnwise,
)
from openfold.model.evoformer import (
EvoformerBlock,
EvoformerStack,
)
from openfold.model.outer_product_mean import OuterProductMean
from openfold.model.msa import (
MSARowAttentionWithPairBias,
MSAColumnAttention,
MSAColumnGlobalAttention,
)
from openfold.model.pair_transition import PairTransition
from openfold.model.primitives import Attention, GlobalAttention from openfold.model.primitives import Attention, GlobalAttention
from openfold.model.structure_module import (
InvariantPointAttention,
BackboneUpdate,
)
from openfold.model.template import TemplatePairStackBlock
from openfold.model.triangular_attention import (
TriangleAttentionStartingNode,
TriangleAttentionEndingNode,
)
from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
)
def script_preset_(model: torch.nn.Module):
"""
TorchScript a handful of low-level but frequently used submodule types
that are known to be scriptable.
Args:
model:
A torch.nn.Module. It should contain at least some modules from
this repository, or this function won't do anything.
"""
script_submodules_(
model,
[
nn.Dropout,
Attention,
GlobalAttention,
EvoformerBlock,
#TemplatePairStackBlock,
],
attempt_trace=False,
batch_dims=None,
)
def _get_module_device(module: torch.nn.Module) -> torch.device:
"""
Fetches the device of a module, assuming that all of the module's
parameters reside on a single device
Args:
module: A torch.nn.Module
Returns:
The module's device
"""
return next(module.parameters()).device
def _trace_module(module, batch_dims=None):
if(batch_dims is None):
batch_dims = ()
# Stand-in values
n_seq = 10
n_res = 10
device = _get_module_device(module)
def msa(channel_dim):
return torch.rand(
(*batch_dims, n_seq, n_res, channel_dim),
device=device,
)
def pair(channel_dim):
return torch.rand(
(*batch_dims, n_res, n_res, channel_dim),
device=device,
)
if(isinstance(module, MSARowAttentionWithPairBias)):
inputs = {
"forward": (
msa(module.c_in), # m
pair(module.c_z), # z
torch.randint(
0, 2,
(*batch_dims, n_seq, n_res)
), # mask
),
}
elif(isinstance(module, MSAColumnAttention)):
inputs = {
"forward": (
msa(module.c_in), # m
torch.randint(
0, 2,
(*batch_dims, n_seq, n_res)
), # mask
),
}
elif(isinstance(module, OuterProductMean)):
inputs = {
"forward": (
msa(module.c_m),
torch.randint(
0, 2,
(*batch_dims, n_seq, n_res)
)
)
}
module = OPM(module)
else:
raise TypeError(
f"tracing is not supported for modules of type {type(module)}"
)
return torch.jit.trace_module(module, inputs)
def _script_submodules_helper_(
model,
types,
attempt_trace,
to_trace,
):
for name, child in model.named_children():
if(types is None or any(isinstance(child, t) for t in types)):
try:
scripted = torch.jit.script(child)
setattr(model, name, scripted)
continue
except (RuntimeError, torch.jit.frontend.NotSupportedError) as e:
if(attempt_trace):
to_trace.add(type(child))
else:
raise e
_script_submodules_helper_(child, types, attempt_trace, to_trace)
def _trace_submodules_(
model,
types,
batch_dims=None,
):
for name, child in model.named_children():
if(any(isinstance(child, t) for t in types)):
traced = _trace_module(child, batch_dims=batch_dims)
setattr(model, name, traced)
else:
_trace_submodules_(child, types, batch_dims=batch_dims)
def script_primitives_(model):
script_submodules_(model, [Attention, GlobalAttention])
def script_submodules_( def script_submodules_(
model: nn.Module, model: nn.Module,
types: Optional[Sequence[type]] = None, types: Optional[Sequence[type]] = None,
attempt_trace: Optional[bool] = True,
batch_dims: Optional[Tuple[int]] = None,
): ):
""" """
Convert all submodules whose types match one of those in the input Convert all submodules whose types match one of those in the input
list to recursively scripted equivalents in place. To script the entire list to recursively scripted equivalents in place. To script the entire
model, just call torch.jit.script on it directly. model, just call torch.jit.script on it directly.
When types is None, all submodules are scripted. When types is None, all submodules are scripted.
Args: Args:
model: A torch.nn.Module model:
types: A list of types of submodules to script A torch.nn.Module
types:
A list of types of submodules to script
attempt_trace:
Whether to attempt to trace specified modules if scripting
fails. Recall that tracing eliminates all conditional
logic---with great tracing comes the mild responsibility of
having to remember to ensure that the modules in question
perform the same computations no matter what.
""" """
for name, child in model.named_children(): to_trace = set()
if(types is None or any(isinstance(child, t) for t in types)):
setattr(model, name, torch.jit.script(child)) # Aggressively script as much as possible first...
else: _script_submodules_helper_(model, types, attempt_trace, to_trace)
script_submodules_(child, types)
# ... and then trace stragglers.
if(attempt_trace and len(to_trace) > 0):
_trace_submodules_(model, to_trace, batch_dims=batch_dims)
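
The new torchscript.py module is driven either by script_preset_ (as the updated inference and training scripts further down show) or, at a lower level, by script_submodules_ with an optional tracing fallback. A hedged usage sketch; the "model_1" preset name is assumed for illustration:

    from openfold.config import model_config
    from openfold.model.model import AlphaFold
    from openfold.model.primitives import Attention, GlobalAttention
    from openfold.model.torchscript import script_preset_, script_submodules_

    config = model_config("model_1")  # preset name assumed
    model = AlphaFold(config).eval()

    # Script the curated set of submodules in place (nn.Dropout, Attention,
    # GlobalAttention, EvoformerBlock), as the run scripts now do.
    script_preset_(model)

    # Finer-grained alternative: choose the types yourself. Types whose
    # scripting fails are collected and traced afterwards, which requires
    # representative batch dims for the dummy inputs.
    # script_submodules_(
    #     model,
    #     [Attention, GlobalAttention],
    #     attempt_trace=True,
    #     batch_dims=(1,),
    # )
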
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
from functools import partialmethod from functools import partialmethod
import math import math
from typing import Optional, List
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -55,7 +57,30 @@ class TriangleAttention(nn.Module): ...@@ -55,7 +57,30 @@ class TriangleAttention(nn.Module):
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
) )
def forward(self, x, chunk_size, mask=None): @torch.jit.ignore
def _chunk(self,
x: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
) -> torch.Tensor:
mha_inputs = {
"q_x": x,
"k_x": x,
"v_x": x,
"biases": biases,
}
return chunk_layer(
self.mha,
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]),
)
def forward(self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
) -> torch.Tensor:
""" """
Args: Args:
x: x:
...@@ -86,21 +111,12 @@ class TriangleAttention(nn.Module): ...@@ -86,21 +111,12 @@ class TriangleAttention(nn.Module):
# [*, 1, H, I, J] # [*, 1, H, I, J]
triangle_bias = triangle_bias.unsqueeze(-4) triangle_bias = triangle_bias.unsqueeze(-4)
mha_inputs = { biases = [mask_bias, triangle_bias]
"q_x": x,
"k_x": x,
"v_x": x,
"biases": [mask_bias, triangle_bias],
}
if chunk_size is not None: if chunk_size is not None:
x = chunk_layer( x = self._chunk(x, biases, chunk_size)
self.mha,
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]),
)
else: else:
x = self.mha(**mha_inputs) x = self.mha(q_x=x, k_x=x, v_x=x, biases=biases)
if not self.starting: if not self.starting:
x = x.transpose(-2, -3) x = x.transpose(-2, -3)
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from functools import partialmethod from functools import partialmethod
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -25,7 +27,6 @@ class TriangleMultiplicativeUpdate(nn.Module): ...@@ -25,7 +27,6 @@ class TriangleMultiplicativeUpdate(nn.Module):
""" """
Implements Algorithms 11 and 12. Implements Algorithms 11 and 12.
""" """
def __init__(self, c_z, c_hidden, _outgoing=True): def __init__(self, c_z, c_hidden, _outgoing=True):
""" """
Args: Args:
...@@ -51,39 +52,16 @@ class TriangleMultiplicativeUpdate(nn.Module): ...@@ -51,39 +52,16 @@ class TriangleMultiplicativeUpdate(nn.Module):
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
cp = self._outgoing_matmul if self._outgoing else self._incoming_matmul def _combine_projections(self,
self.combine_projections = cp a: torch.Tensor,
b: torch.Tensor,
def _outgoing_matmul( ) -> torch.Tensor:
self, raise NotImplementedError("This method needs to be overridden")
a: torch.Tensor, # [*, N_i, N_k, C]
b: torch.Tensor, # [*, N_j, N_k, C]
):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, (2, 0, 1)),
permute_final_dims(b, (2, 1, 0)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, (1, 2, 0))
def _incoming_matmul(
self,
a: torch.Tensor, # [*, N_k, N_i, C]
b: torch.Tensor, # [*, N_k, N_j, C]
):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, (2, 1, 0)),
permute_final_dims(b, (2, 0, 1)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, (1, 2, 0))
def forward(self, z, mask=None): def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
""" """
Args: Args:
x: x:
...@@ -103,7 +81,7 @@ class TriangleMultiplicativeUpdate(nn.Module): ...@@ -103,7 +81,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
a = a * mask a = a * mask
b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z)) b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
b = b * mask b = b * mask
x = self.combine_projections(a, b) x = self._combine_projections(a, b)
x = self.layer_norm_out(x) x = self.layer_norm_out(x)
x = self.linear_z(x) x = self.linear_z(x)
g = self.sigmoid(self.linear_g(z)) g = self.sigmoid(self.linear_g(z))
...@@ -116,19 +94,36 @@ class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate): ...@@ -116,19 +94,36 @@ class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
""" """
Implements Algorithm 11. Implements Algorithm 11.
""" """
def _combine_projections(
self,
a: torch.Tensor, # [*, N_i, N_k, C]
b: torch.Tensor, # [*, N_j, N_k, C]
):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, (2, 0, 1)),
permute_final_dims(b, (2, 1, 0)),
)
__init__ = partialmethod( # [*, N_i, N_j, C]
TriangleMultiplicativeUpdate.__init__, return permute_final_dims(p, (1, 2, 0))
_outgoing=True,
)
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate): class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
""" """
Implements Algorithm 12. Implements Algorithm 12.
""" """
def _combine_projections(
self,
a: torch.Tensor, # [*, N_k, N_i, C]
b: torch.Tensor, # [*, N_k, N_j, C]
):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, (2, 1, 0)),
permute_final_dims(b, (2, 0, 1)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, (1, 2, 0))
__init__ = partialmethod(
TriangleMultiplicativeUpdate.__init__,
_outgoing=False,
)
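
The TriangleMultiplicativeUpdate change swaps runtime method selection (storing a bound method on self in __init__) for a _combine_projections hook that each subclass overrides, a shape TorchScript can compile. A stripped-down, hypothetical version of the same move, with invented Toy* names:

    import torch
    import torch.nn as nn


    class ToyUpdateBase(nn.Module):
        def _combine(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError("Override in a subclass")

        def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return self._combine(a, b)


    class ToyOutgoing(ToyUpdateBase):
        def _combine(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return torch.matmul(a, b.transpose(-1, -2))


    class ToyIncoming(ToyUpdateBase):
        def _combine(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return torch.matmul(a.transpose(-1, -2), b)


    out = torch.jit.script(ToyOutgoing())  # each concrete subclass scripts cleanly
    inc = torch.jit.script(ToyIncoming())
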
...@@ -631,7 +631,7 @@ _qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)} ...@@ -631,7 +631,7 @@ _qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
def _to_mat(pairs): def _to_mat(pairs):
mat = torch.zeros((4, 4)) mat = np.zeros((4, 4))
for pair in pairs: for pair in pairs:
key, value = pair key, value = pair
ind = _qtr_ind_dict[key] ind = _qtr_ind_dict[key]
......
...@@ -17,10 +17,11 @@ import torch ...@@ -17,10 +17,11 @@ import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from typing import Any, Tuple, List, Callable from typing import Any, Tuple, List, Callable
BLOCK_ARG = Any BLOCK_ARG = Any
BLOCK_ARGS = List[BLOCK_ARG] BLOCK_ARGS = List[BLOCK_ARG]
@torch.jit.ignore
def checkpoint_blocks( def checkpoint_blocks(
blocks: List[Callable], blocks: List[Callable],
args: BLOCK_ARGS, args: BLOCK_ARGS,
......
...@@ -217,6 +217,11 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -217,6 +217,11 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"attention": AttentionGatedParams(matt.mha), "attention": AttentionGatedParams(matt.mha),
} }
MSAColAttParams = lambda matt: {
"query_norm": LayerNormParams(matt._msa_att.layer_norm_m),
"attention": AttentionGatedParams(matt._msa_att.mha),
}
MSAGlobalAttParams = lambda matt: { MSAGlobalAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m), "query_norm": LayerNormParams(matt.layer_norm_m),
"attention": GlobalAttentionParams(matt.global_attention), "attention": GlobalAttentionParams(matt.global_attention),
...@@ -270,7 +275,7 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -270,7 +275,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
msa_col_att_params = MSAGlobalAttParams(b.msa_att_col) msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
else: else:
col_att_name = "msa_column_attention" col_att_name = "msa_column_attention"
msa_col_att_params = MSAAttParams(b.msa_att_col) msa_col_att_params = MSAColAttParams(b.msa_att_col)
d = { d = {
"msa_row_attention_with_pair_bias": MSAAttPairBiasParams( "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
......
...@@ -107,6 +107,22 @@ def tree_map(fn, tree, leaf_type): ...@@ -107,6 +107,22 @@ def tree_map(fn, tree, leaf_type):
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
def _fetch_dims(tree):
shapes = []
tree_type = type(tree)
if tree_type is dict:
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
def chunk_layer( def chunk_layer(
layer: Callable, layer: Callable,
...@@ -141,33 +157,17 @@ def chunk_layer( ...@@ -141,33 +157,17 @@ def chunk_layer(
if not (len(inputs) > 0): if not (len(inputs) > 0):
raise ValueError("Must provide at least one input") raise ValueError("Must provide at least one input")
def fetch_dims(tree): initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
shapes = []
tree_type = type(tree)
if tree_type is dict:
for v in tree.values():
shapes.extend(fetch_dims(v))
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(fetch_dims(t))
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def prep_inputs(t): def _prep_inputs(t):
# TODO: make this more memory efficient. This sucks # TODO: make this more memory efficient. This sucks
if not sum(t.shape[:no_batch_dims]) == no_batch_dims: if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:]) t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:]) t = t.reshape(-1, *t.shape[no_batch_dims:])
return t return t
flattened_inputs = tensor_tree_map(prep_inputs, inputs) flattened_inputs = tensor_tree_map(_prep_inputs, inputs)
flat_batch_dim = 1 flat_batch_dim = 1
for d in orig_batch_dims: for d in orig_batch_dims:
......
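
For reference, chunk_layer (whose nested fetch_dims helper is hoisted to module level here) applies a callable to its inputs chunk_size flattened-batch rows at a time and stitches the outputs back together. A small self-contained check with arbitrary shapes and an invented "transition" callable:

    import torch

    from openfold.utils.tensor_utils import chunk_layer

    linear = torch.nn.Linear(8, 8)
    m = torch.rand(2, 5, 7, 8)    # [batch, N_seq, N_res, C]
    mask = torch.ones(2, 5, 7, 1)

    def transition(m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return linear(m) * mask

    full = transition(m, mask)
    chunked = chunk_layer(
        transition,
        {"m": m, "mask": mask},
        chunk_size=4,
        no_batch_dims=len(m.shape[:-2]),  # two leading batch dims: (2, 5)
    )
    assert torch.allclose(full, chunked)
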
...@@ -31,7 +31,7 @@ import torch ...@@ -31,7 +31,7 @@ import torch
from openfold.config import model_config from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_primitives_ from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax import openfold.np.relax.relax as relax
from openfold.utils.import_weights import ( from openfold.utils.import_weights import (
...@@ -49,9 +49,9 @@ def main(args): ...@@ -49,9 +49,9 @@ def main(args):
model = AlphaFold(config) model = AlphaFold(config)
model = model.eval() model = model.eval()
import_jax_weights_(model, args.param_path) import_jax_weights_(model, args.param_path)
script_primitives_(model) script_preset_(model)
model = model.to(args.model_device) model = model.to(args.model_device)
template_featurizer = templates.TemplateHitFeaturizer( template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir, mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date, max_template_date=args.max_template_date,
......
...@@ -2,7 +2,7 @@ import argparse ...@@ -2,7 +2,7 @@ import argparse
import logging import logging
import os import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "6" os.environ["CUDA_VISIBLE_DEVICES"] = "6"
#os.environ["MASTER_ADDR"]="10.119.81.14" #os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069" #os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0" #os.environ["NODE_RANK"]="0"
...@@ -23,6 +23,7 @@ from openfold.data.data_modules import ( ...@@ -23,6 +23,7 @@ from openfold.data.data_modules import (
DummyDataLoader, DummyDataLoader,
) )
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.utils.callbacks import ( from openfold.utils.callbacks import (
EarlyStoppingVerbose, EarlyStoppingVerbose,
) )
...@@ -64,10 +65,6 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -64,10 +65,6 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss # Compute loss
loss = self.loss(outputs, batch) loss = self.loss(outputs, batch)
#if(torch.isnan(loss) or torch.isinf(loss)):
# logging.warning("loss is NaN. Skipping example...")
# loss = loss.new_tensor(0., requires_grad=True)
return {"loss": loss} return {"loss": loss}
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
...@@ -121,6 +118,10 @@ def main(args): ...@@ -121,6 +118,10 @@ def main(args):
sd = {k[len("module."):]:v for k,v in sd.items()} sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd) model_module.load_state_dict(sd)
logging.info("Successfully loaded model weights...") logging.info("Successfully loaded model weights...")
# TorchScript components of the model
script_preset_(model_module)
#data_module = DummyDataLoader("batch.pickle") #data_module = DummyDataLoader("batch.pickle")
data_module = OpenFoldDataModule( data_module = OpenFoldDataModule(
config=config.data, config=config.data,
......