Commit 3260242f authored by Gustaf Ahdritz

Add mechanism for TorchScript'ing primitives during inference

parent 1f87c6e5
@@ -166,9 +166,8 @@ class Linear(nn.Linear):
class Attention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer
initialization.
initialization. Allows multiple bias vectors.
"""
def __init__(
self,
c_q: int,
......
from typing import Optional, Sequence

import torch
import torch.nn as nn


def script_submodules_(
    model: nn.Module,
    types: Optional[Sequence[type]] = None,
):
    """
    Convert all submodules whose types match one of those in the input
    list to recursively scripted equivalents in place. To script the
    entire model, just call torch.jit.script on it directly.

    Args:
        model: A torch.nn.Module
        types: A list of types of submodules to script. If None, all
            direct children are scripted.
    """
    for name, child in model.named_children():
        if(types is None or any(isinstance(child, t) for t in types)):
            # torch.jit.script recursively scripts the child's submodules
            setattr(model, name, torch.jit.script(child))
        else:
            script_submodules_(child, types)
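A minimal usage sketch of the new helper (not part of the commit; `ToyBlock` and `ToyModel` are hypothetical stand-ins for real submodules):

```python
import torch
import torch.nn as nn

from openfold.utils.torchscript_utils import script_submodules_


class ToyBlock(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = ToyBlock()
        self.linear = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.block(x))


model = ToyModel().eval()
# Script only the ToyBlock submodule; everything else stays in eager mode
script_submodules_(model, [ToyBlock])
assert isinstance(model.block, torch.jit.ScriptModule)
assert not isinstance(model.linear, torch.jit.ScriptModule)
```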
@@ -31,22 +31,30 @@ import torch
from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.model.model import AlphaFold
from openfold.model.primitives import Attention, GlobalAttention
from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax
from openfold.utils.import_weights import (
    import_jax_weights_,
)
from openfold.utils.torchscript_utils import script_submodules_
from openfold.utils.tensor_utils import (
    tensor_tree_map,
)
from scripts.utils import add_data_args


def script_primitives_(model):
    script_submodules_(model, [Attention, GlobalAttention])


def main(args):
    config = model_config(args.model_name)
    model = AlphaFold(config)
    model = model.eval()
    import_jax_weights_(model, args.param_path)
    script_primitives_(model)
    model = model.to(args.model_device)

    template_featurizer = templates.TemplateHitFeaturizer(
......
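As a quick sanity check that `script_primitives_(model)` took effect, one could count the scripted submodules (a sketch, not part of the commit; the count includes modules nested inside scripted ones):

```python
# Count submodules that are now TorchScript modules
n_scripted = sum(
    isinstance(m, torch.jit.ScriptModule) for m in model.modules()
)
print(f"Scripted {n_scripted} submodules")
```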