Commit 60d148bd authored by zhuww

fix some problems and add save .pkl

parent 6a41c3e7
@@ -4,7 +4,7 @@ import logging
 import torch
 
-_triton_available = True
+_triton_available = False
 if _triton_available:
     try:
         from .triton.softmax import softmax_triton_kernel_wrapper
...
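The hunk above simply hard-codes _triton_available = False, so the Triton softmax kernel is never imported. A more common pattern is to probe for the dependency at import time; the following is a minimal, self-contained sketch of such a probe (illustrative only, not the actual layout of FastFold's kernel package):

# Minimal sketch: detect Triton at import time instead of hard-coding the flag.
# (Illustrative only; FastFold's kernel package may structure this differently.)
import importlib.util

_triton_available = importlib.util.find_spec("triton") is not None
if _triton_available:
    try:
        import triton  # noqa: F401  # make sure the module actually imports
    except ImportError:
        _triton_available = False

print(f"Triton softmax kernel enabled: {_triton_available}")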
@@ -399,7 +399,6 @@ class AlphaFold(nn.Module):
         outputs["single"] = s
 
         # Predict 3D structure
-        z = [z]
         outputs_sm = self.structure_module(
             s,
             z,
...
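The removed z = [z] line wrapped the pair representation in a one-element list before handing it to the structure module, which lets the callee drop the only remaining reference and free the tensor early; it is removed together with the matching del z in the next hunk. A minimal sketch of that ownership-transfer pattern (generic PyTorch, not FastFold's actual API):

# Sketch of the "tensor inside a list" hand-off that z = [z] enabled
# (generic PyTorch; all names here are illustrative).
import torch

def consume(z_ref):
    z = z_ref.pop()      # take the tensor out of the container
    out = z.mean()       # use it once
    del z                # drop the last reference so the memory can be reused
    return out

pair = [torch.randn(256, 256, 128)]
result = consume(pair)
assert len(pair) == 0    # the caller no longer holds the big tensor either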
@@ -787,7 +787,6 @@ class StructureModule(nn.Module):
         for i in range(self.no_blocks):
             # [*, N, C_s]
             s = s + self.ipa(s, z, rigids, mask)
-            del z
             s = self.ipa_dropout(s)
             torch.cuda.empty_cache()
             s = self.layer_norm_ipa(s)
...
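Here del z is removed while the torch.cuda.empty_cache() call stays. Since z is read by self.ipa on every pass of the block loop, deleting the name inside the loop would raise UnboundLocalError on the second iteration, which is presumably part of the "fix some problems" in the commit message. A small self-contained illustration of that failure mode:

# Why deleting a loop-carried local breaks later iterations
# (plain Python, mirroring the removed del z).
def run_blocks(s, z, no_blocks=2):
    for _ in range(no_blocks):
        s = s + z    # fine on the first pass
        del z        # second pass: z no longer exists
    return s

try:
    run_blocks(1.0, 2.0)
except UnboundLocalError as err:
    print("as expected:", err)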
@@ -21,6 +21,7 @@ import time
 from datetime import date
 import tempfile
 import contextlib
+import logging
 
 import numpy as np
 import torch
@@ -43,6 +44,10 @@ from fastfold.data.parsers import parse_fasta
 from fastfold.utils.import_weights import import_jax_weights_
 from fastfold.utils.tensor_utils import tensor_tree_map
 
+logging.basicConfig()
+logger = logging.getLogger(__file__)
+logger.setLevel(level=logging.INFO)
+
 if int(torch.__version__.split(".")[0]) >= 1 and int(torch.__version__.split(".")[1]) > 11:
     torch.backends.cuda.matmul.allow_tf32 = True
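The context lines above gate TF32 matmuls on the installed torch version. As written, the check compares the minor version alone, so a release such as torch 2.0 (whose minor version is 0) would not pass it. A hedged alternative using the third-party packaging library, assuming it is available:

# Hedged alternative to the string-splitting version gate
# (assumes the "packaging" package is installed).
import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse("1.12"):
    # TF32 matmuls trade a small amount of precision for speed on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True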
@@ -450,6 +455,15 @@ def inference_monomer_model(args):
     # with open(relaxed_output_path, 'w') as f:
     #     f.write(relaxed_pdb_str)
 
+    if(args.save_outputs):
+        output_dict_path = os.path.join(
+            args.output_dir, f'{tag}_{args.model_name}_output_dict.pkl'
+        )
+        with open(output_dict_path, "wb") as fp:
+            pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
+
+        logger.info(f"Model output written to {output_dict_path}...")
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
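The new block serializes the full output dictionary next to the predicted structure whenever --save_outputs is passed (it relies on pickle and os already being imported elsewhere in the script, since the hunk adds neither). Reading the file back is a plain pickle load; the path below is hypothetical and the available keys depend on the model configuration:

# Sketch: load a saved output dict back for downstream analysis.
# The file name is hypothetical; inspect the keys rather than assuming them.
import pickle

output_dict_path = "outputs/example_model_1_output_dict.pkl"
with open(output_dict_path, "rb") as fp:
    out = pickle.load(fp)

print(sorted(out.keys()))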
@@ -483,6 +497,10 @@ if __name__ == "__main__":
                         help="""Path to model parameters. If None, parameters are selected
                                 automatically according to the model name from
                                 ./data/params""")
+    parser.add_argument(
+        "--save_outputs", action="store_true", default=False,
+        help="Whether to save all model outputs, including embeddings, etc."
+    )
     parser.add_argument("--cpus",
                         type=int,
                         default=12,
...
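--save_outputs is a standard store_true switch: it defaults to False and becomes True only when the flag is given on the command line. A minimal stand-alone demonstration (the real script defines many more arguments):

# Stand-alone sketch of the new flag's behaviour.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--save_outputs", action="store_true", default=False,
    help="Whether to save all model outputs, including embeddings, etc."
)

print(parser.parse_args([]).save_outputs)                  # False
print(parser.parse_args(["--save_outputs"]).save_outputs)  # True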