Commit a2ab7ab7 authored by Sam DeLuca's avatar Sam DeLuca
Browse files

more fixes

parent ce042962
......@@ -237,14 +237,13 @@ def load_models_from_command_line(args, config):
for path in args.openfold_checkpoint_path.split(","):
model = AlphaFold(config)
model = model.eval()
checkpoint_basename = None
if os.path.isdir(path):
# A DeepSpeed checkpoint
checkpoint_basename = os.path.splitext(
os.path.basename(
os.path.normpath(path)
)
)[0]
if os.path.isdir(path):
# A DeepSpeed checkpoint
ckpt_path = os.path.join(
args.output_dir,
checkpoint_basename + ".pt",
......@@ -313,7 +312,7 @@ def main(args):
for fasta_file in os.listdir(args.fasta_dir):
batch, tag, feature_dict = generate_batch(
batch_data = generate_batch(
fasta_file,
args.fasta_dir,
alignment_dir,
......@@ -321,6 +320,12 @@ def main(args):
feature_processor,
prediction_dir)
if batch_data is None:
# this file has already been processed
continue
batch, tag, feature_dict = batch_data
for model, model_version in load_models_from_command_line(args, config):
working_batch = deepcopy(batch)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment