Commit 60d148bd authored by zhuww's avatar zhuww
Browse files

fix some problems and add save .pkl

parent 6a41c3e7
......@@ -4,7 +4,7 @@ import logging
import torch
_triton_available = True
_triton_available = False
if _triton_available:
try:
from .triton.softmax import softmax_triton_kernel_wrapper
......
......@@ -399,7 +399,6 @@ class AlphaFold(nn.Module):
outputs["single"] = s
# Predict 3D structure
z = [z]
outputs_sm = self.structure_module(
s,
z,
......
......@@ -787,7 +787,6 @@ class StructureModule(nn.Module):
for i in range(self.no_blocks):
# [*, N, C_s]
s = s + self.ipa(s, z, rigids, mask)
del z
s = self.ipa_dropout(s)
torch.cuda.empty_cache()
s = self.layer_norm_ipa(s)
......
......@@ -21,6 +21,7 @@ import time
from datetime import date
import tempfile
import contextlib
import logging
import numpy as np
import torch
......@@ -43,6 +44,10 @@ from fastfold.data.parsers import parse_fasta
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
if int(torch.__version__.split(".")[0]) >= 1 and int(torch.__version__.split(".")[1]) > 11:
torch.backends.cuda.matmul.allow_tf32 = True
......@@ -449,6 +454,15 @@ def inference_monomer_model(args):
# f'{tag}_{args.model_name}_relaxed.pdb')
# with open(relaxed_output_path, 'w') as f:
# f.write(relaxed_pdb_str)
if(args.save_outputs):
output_dict_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}_output_dict.pkl'
)
with open(output_dict_path, "wb") as fp:
pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Model output written to {output_dict_path}...")
if __name__ == "__main__":
......@@ -483,6 +497,10 @@ if __name__ == "__main__":
help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from
./data/params""")
parser.add_argument(
"--save_outputs", action="store_true", default=False,
help="Whether to save all model outputs, including embeddings, etc."
)
parser.add_argument("--cpus",
type=int,
default=12,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment