Commit 9924e7be authored by Shenggan's avatar Shenggan
Browse files

use torch.multiprocess to launch multi-gpu inference

parent f44557ed
......@@ -72,8 +72,9 @@ model = inject_fastnn(model)
For Dynamic Axial Parallelism, you can refer to `./inference.py`. Here is an example of 2 GPUs parallel inference:
```shell
torchrun --nproc_per_node=2 inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
--output_dir ./ \
--gpus 2 \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
......
......@@ -34,7 +34,7 @@ def init_dap(tensor_model_parallel_size_=None):
set_missing_distributed_environ('RANK', 0)
set_missing_distributed_environ('LOCAL_RANK', 0)
set_missing_distributed_environ('MASTER_ADDR', "localhost")
set_missing_distributed_environ('MASTER_PORT', -1)
set_missing_distributed_environ('MASTER_PORT', 18417)
colossalai.launch_from_torch(
config={"parallel": dict(tensor=dict(size=tensor_model_parallel_size_))})
......@@ -22,6 +22,7 @@ from datetime import date
import numpy as np
import torch
import torch.multiprocessing as mp
from fastfold.model.hub import AlphaFold
import fastfold
......@@ -73,19 +74,39 @@ def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument('--release_dates_path', type=str, default=None)
def main(args):
def inference_model(rank, world_size, result_q, batch, args):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
torch.cuda.set_device(rank)
config = model_config(args.model_name)
model = AlphaFold(config)
import_jax_weights_(model, args.param_path, version=args.model_name)
model = inject_fastnn(model)
model = model.eval()
#script_preset_(model)
model = model.cuda()
with torch.no_grad():
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
t = time.perf_counter()
out = model(batch)
print(f"Inference time: {time.perf_counter() - t}")
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
result_q.put(out)
torch.distributed.barrier()
torch.cuda.synchronize()
def main(args):
config = model_config(args.model_name)
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
......@@ -124,96 +145,83 @@ def main(args):
for tag, seq in zip(tags, seqs):
batch = [None]
if torch.distributed.get_rank() == 0:
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
print("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
if (args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(fasta_path, local_alignment_dir)
feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
alignment_dir=local_alignment_dir)
# Remove temporary FASTA file
os.remove(fasta_path)
processed_feature_dict = feature_processor.process_features(
feature_dict,
mode='predict',
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
print("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
if (args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(fasta_path, local_alignment_dir)
batch = [processed_feature_dict]
feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
alignment_dir=local_alignment_dir)
torch.distributed.broadcast_object_list(batch, src=0)
batch = batch[0]
# Remove temporary FASTA file
os.remove(fasta_path)
print("Executing model...")
processed_feature_dict = feature_processor.process_features(
feature_dict,
mode='predict',
)
with torch.no_grad():
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
batch = processed_feature_dict
t = time.perf_counter()
out = model(batch)
print(f"Inference time: {time.perf_counter() - t}")
manager = mp.Manager()
result_q = manager.Queue()
torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
torch.distributed.barrier()
out = result_q.get()
if torch.distributed.get_rank() == 0:
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
unrelaxed_protein = protein.from_prediction(features=batch,
result=out,
b_factors=plddt_b_factors)
unrelaxed_protein = protein.from_prediction(features=batch,
result=out,
b_factors=plddt_b_factors)
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation(
use_gpu=True,
**config.relax,
)
amber_relaxer = relax.AmberRelaxation(
use_gpu=True,
**config.relax,
)
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
torch.distributed.barrier()
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if __name__ == "__main__":
......@@ -252,6 +260,10 @@ if __name__ == "__main__":
type=int,
default=12,
help="""Number of CPUs with which to run alignment tools""")
parser.add_argument("--gpus",
type=int,
default=1,
help="""Number of GPUs with which to run inference""")
parser.add_argument('--preset',
type=str,
default='full_dbs',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment