Commit 9924e7be authored by Shenggan

use torch.multiprocess to launch multi-gpu inference

parent f44557ed
@@ -72,8 +72,9 @@ model = inject_fastnn(model)
 For Dynamic Axial Parallelism, you can refer to `./inference.py`. Here is an example of 2 GPUs parallel inference:
 
 ```shell
-torchrun --nproc_per_node=2 inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
+python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
     --output_dir ./ \
+    --gpus 2 \
     --uniref90_database_path data/uniref90/uniref90.fasta \
     --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
     --pdb70_database_path data/pdb70/pdb70 \
...
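With this commit the README example no longer goes through an external launcher: `inference.py` forks its own workers via `torch.multiprocessing`, so the degree of Dynamic Axial Parallelism is set by the new `--gpus` flag rather than torchrun's `--nproc_per_node`.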
@@ -34,7 +34,7 @@ def init_dap(tensor_model_parallel_size_=None):
         set_missing_distributed_environ('RANK', 0)
         set_missing_distributed_environ('LOCAL_RANK', 0)
         set_missing_distributed_environ('MASTER_ADDR', "localhost")
-        set_missing_distributed_environ('MASTER_PORT', -1)
+        set_missing_distributed_environ('MASTER_PORT', 18417)
 
     colossalai.launch_from_torch(
         config={"parallel": dict(tensor=dict(size=tensor_model_parallel_size_))})
...
@@ -22,6 +22,7 @@ from datetime import date
 import numpy as np
 import torch
+import torch.multiprocessing as mp
 
 from fastfold.model.hub import AlphaFold
 import fastfold
@@ -73,19 +74,39 @@ def add_data_args(parser: argparse.ArgumentParser):
     parser.add_argument('--release_dates_path', type=str, default=None)
 
-def main(args):
+def inference_model(rank, world_size, result_q, batch, args):
+    os.environ['RANK'] = str(rank)
+    os.environ['LOCAL_RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
     # init distributed for Dynamic Axial Parallelism
     fastfold.distributed.init_dap()
+    torch.cuda.set_device(rank)
 
     config = model_config(args.model_name)
     model = AlphaFold(config)
     import_jax_weights_(model, args.param_path, version=args.model_name)
 
     model = inject_fastnn(model)
     model = model.eval()
+    #script_preset_(model)
     model = model.cuda()
 
+    with torch.no_grad():
+        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
+
+        t = time.perf_counter()
+        out = model(batch)
+        print(f"Inference time: {time.perf_counter() - t}")
+
+    out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
+
+    result_q.put(out)
+
+    torch.distributed.barrier()
+    torch.cuda.synchronize()
+
+
+def main(args):
+    config = model_config(args.model_name)
+
     template_featurizer = templates.TemplateHitFeaturizer(
         mmcif_dir=args.template_mmcif_dir,
         max_template_date=args.max_template_date,
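The new `inference_model` follows the calling convention of `torch.multiprocessing.spawn`: the process index is prepended to the tuple passed via `args=`, which is why `rank` is its first parameter. Because spawned workers inherit no launcher-provided environment, `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` are exported by hand before `init_dap()` brings up the process group. A stripped-down, runnable sketch of that contract (illustrative names only, not FastFold code):

```python
import os
import torch.multiprocessing as mp

# Illustrative worker: spawn() invokes it as fn(rank, *args), so the
# process index arrives first and the args tuple follows in order.
def worker(rank, world_size):
    # Recreate what a launcher such as torchrun would normally export.
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    print(f"worker {rank}/{world_size} environment ready")

if __name__ == "__main__":
    mp.spawn(worker, nprocs=2, args=(2,))
```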
@@ -124,7 +145,7 @@ def main(args):
     for tag, seq in zip(tags, seqs):
         batch = [None]
-        if torch.distributed.get_rank() == 0:
         fasta_path = os.path.join(args.output_dir, "tmp.fasta")
         with open(fasta_path, "w") as fp:
             fp.write(f">{tag}\n{seq}")
@@ -160,27 +181,16 @@ def main(args):
             mode='predict',
         )
-        batch = [processed_feature_dict]
-        torch.distributed.broadcast_object_list(batch, src=0)
-        batch = batch[0]
-
-        print("Executing model...")
-        with torch.no_grad():
-            batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
-
-            t = time.perf_counter()
-            out = model(batch)
-            print(f"Inference time: {time.perf_counter() - t}")
-
-        torch.distributed.barrier()
-
-        if torch.distributed.get_rank() == 0:
+        batch = processed_feature_dict
+
+        manager = mp.Manager()
+        result_q = manager.Queue()
+        torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
+        out = result_q.get()
+
         # Toss out the recycling dimensions --- we don't need them anymore
         batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
-        out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
 
         plddt = out["plddt"]
         mean_plddt = np.mean(plddt)
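In the new `main`, results come back through a `Manager` proxy queue, which is picklable and therefore safe to hand to every spawned worker through `spawn()`'s `args=`; each rank `put`s its output and the parent reads a single copy back. A self-contained sketch of this launch-and-collect pattern (toy payload in place of the model output):

```python
import torch.multiprocessing as mp

# Toy worker standing in for inference_model: push one result per rank.
def run(rank, world_size, result_q):
    result_q.put(f"output from rank {rank}")

if __name__ == "__main__":
    gpus = 2
    manager = mp.Manager()
    result_q = manager.Queue()  # proxy queue survives pickling into workers
    mp.spawn(run, nprocs=gpus, args=(gpus, result_q))
    print(result_q.get())  # one copy suffices when ranks agree on the output
```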
@@ -213,8 +223,6 @@ def main(args):
         with open(relaxed_output_path, 'w') as f:
             f.write(relaxed_pdb_str)
 
-    torch.distributed.barrier()
-
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -252,6 +260,10 @@ if __name__ == "__main__":
                         type=int,
                         default=12,
                         help="""Number of CPUs with which to run alignment tools""")
+    parser.add_argument("--gpus",
+                        type=int,
+                        default=1,
+                        help="""Number of GPUs with which to run inference""")
     parser.add_argument('--preset',
                         type=str,
                         default='full_dbs',
...