"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "6b3e52c9e820cbbdf04a113663e54d1e1ba9fbcd"
Commit 8d1119df authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Update TorchScript call in inference script

parent 326e08de
...@@ -31,13 +31,12 @@ import torch ...@@ -31,13 +31,12 @@ import torch
from openfold.config import model_config from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.model.primitives import Attention, GlobalAttention from openfold.model.torchscript import script_primitives_
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax import openfold.np.relax.relax as relax
from openfold.utils.import_weights import ( from openfold.utils.import_weights import (
import_jax_weights_, import_jax_weights_,
) )
from openfold.utils.torchscript_utils import script_submodules_
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tensor_tree_map, tensor_tree_map,
) )
...@@ -45,10 +44,6 @@ from openfold.utils.tensor_utils import ( ...@@ -45,10 +44,6 @@ from openfold.utils.tensor_utils import (
from scripts.utils import add_data_args from scripts.utils import add_data_args
def script_primitives_(model):
script_submodules_(model, [Attention, GlobalAttention])
def main(args): def main(args):
config = model_config(args.model_name) config = model_config(args.model_name)
model = AlphaFold(config) model = AlphaFold(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment