Commit 267b1bfd authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add warning to run script

parent a3a27143
...@@ -118,7 +118,7 @@ def main(args): ...@@ -118,7 +118,7 @@ def main(args):
t = time.time() t = time.time()
out = model(batch) out = model(batch)
logging.info(f"Inference time: {time.time() - t}") logging.info(f"Inference time: {time.time() - t}")
# Toss out the recycling dimensions --- we don't need them anymore # Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch) batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out) out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
...@@ -129,7 +129,7 @@ def main(args): ...@@ -129,7 +129,7 @@ def main(args):
plddt_b_factors = np.repeat( plddt_b_factors = np.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1 plddt[..., None], residue_constants.atom_type_num, axis=-1
) )
unrelaxed_protein = protein.from_prediction( unrelaxed_protein = protein.from_prediction(
features=batch, features=batch,
result=out, result=out,
...@@ -182,7 +182,7 @@ if __name__ == "__main__": ...@@ -182,7 +182,7 @@ if __name__ == "__main__":
) )
parser.add_argument( parser.add_argument(
"--cpus", type=int, default=4, "--cpus", type=int, default=4,
help="""Number of CPUs to use to run alignment tools""" help="""Number of CPUs with which to run alignment tools"""
) )
parser.add_argument( parser.add_argument(
'--preset', type=str, default='full_dbs', '--preset', type=str, default='full_dbs',
...@@ -200,6 +200,12 @@ if __name__ == "__main__": ...@@ -200,6 +200,12 @@ if __name__ == "__main__":
"params_" + args.model_name + ".npz" "params_" + args.model_name + ".npz"
) )
if(args.model_device == "cpu" and torch.cuda.is_available()):
logging.warning(
"""The model is being run on CPU. Consider specifying
--model_device for better performance"""
)
if(args.bfd_database_path is None and if(args.bfd_database_path is None and
args.small_bfd_database_path is None): args.small_bfd_database_path is None):
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment