"vscode:/vscode.git/clone" did not exist on "ee716e43466a62dd3a671cbf768f01ecc118a8aa"
Commit 9924e7be authored by Shenggan

use torch.multiprocessing to launch multi-GPU inference

parent f44557ed
@@ -72,8 +72,9 @@ model = inject_fastnn(model)
 For Dynamic Axial Parallelism, you can refer to `./inference.py`. Here is an example of 2 GPUs parallel inference:
 ```shell
-torchrun --nproc_per_node=2 inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
+python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
     --output_dir ./ \
+    --gpus 2 \
     --uniref90_database_path data/uniref90/uniref90.fasta \
     --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
     --pdb70_database_path data/pdb70/pdb70 \
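Note on the change: `torchrun` started one process per GPU and injected RANK/WORLD_SIZE into each worker's environment, whereas the new entry point is a single `python` process that spawns its own workers. A minimal sketch of the spawn semantics this relies on (the `run` function is a hypothetical stand-in, not FastFold code):

```python
import torch.multiprocessing as mp

def run(rank, world_size):
    # mp.spawn passes the worker index as the first positional argument,
    # playing the role that torchrun's RANK environment variable played
    print(f"worker {rank} of {world_size}")

if __name__ == "__main__":
    world_size = 2  # corresponds to --gpus 2 above
    mp.spawn(run, nprocs=world_size, args=(world_size,))
```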
@@ -34,7 +34,7 @@ def init_dap(tensor_model_parallel_size_=None):
         set_missing_distributed_environ('RANK', 0)
         set_missing_distributed_environ('LOCAL_RANK', 0)
         set_missing_distributed_environ('MASTER_ADDR', "localhost")
-        set_missing_distributed_environ('MASTER_PORT', -1)
+        set_missing_distributed_environ('MASTER_PORT', 18417)
     colossalai.launch_from_torch(
         config={"parallel": dict(tensor=dict(size=tensor_model_parallel_size_))})
@@ -22,6 +22,7 @@ from datetime import date
 import numpy as np
 import torch
+import torch.multiprocessing as mp
 from fastfold.model.hub import AlphaFold
 import fastfold
@@ -73,19 +74,39 @@ def add_data_args(parser: argparse.ArgumentParser):
     parser.add_argument('--release_dates_path', type=str, default=None)


-def main(args):
+def inference_model(rank, world_size, result_q, batch, args):
+    os.environ['RANK'] = str(rank)
+    os.environ['LOCAL_RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    # init distributed for Dynamic Axial Parallelism
+    fastfold.distributed.init_dap()
+    torch.cuda.set_device(rank)
+
+    config = model_config(args.model_name)
+    model = AlphaFold(config)
+    import_jax_weights_(model, args.param_path, version=args.model_name)
+
+    model = inject_fastnn(model)
+    model = model.eval()
+    # script_preset_(model)
+    model = model.cuda()
+
+    with torch.no_grad():
+        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
+
+        t = time.perf_counter()
+        out = model(batch)
+        print(f"Inference time: {time.perf_counter() - t}")
+
+    out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
+    result_q.put(out)
+
+    torch.distributed.barrier()
+    torch.cuda.synchronize()
+
+
+def main(args):
     config = model_config(args.model_name)
     template_featurizer = templates.TemplateHitFeaturizer(
         mmcif_dir=args.template_mmcif_dir,
         max_template_date=args.max_template_date,
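`inference_model` recreates by hand the environment torchrun used to provide (RANK, LOCAL_RANK, WORLD_SIZE) before calling `init_dap`, then converts the finished output to NumPy with `tensor_tree_map` so it can travel through the result queue. A simplified stand-in for that tree-map pattern (not the implementation the script imports):

```python
import numpy as np
import torch

def tree_map(fn, tree):
    # simplified stand-in: recurse through nested dicts, apply fn at leaves
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)

out = {"plddt": torch.rand(10), "nested": {"x": torch.rand(2, 3)}}
out = tree_map(lambda x: np.array(x.cpu()), out)  # mirrors the diff above
```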
@@ -124,7 +145,7 @@ def main(args):
     for tag, seq in zip(tags, seqs):
-        batch = [None]
-        if torch.distributed.get_rank() == 0:
-            fasta_path = os.path.join(args.output_dir, "tmp.fasta")
-            with open(fasta_path, "w") as fp:
-                fp.write(f">{tag}\n{seq}")
+        fasta_path = os.path.join(args.output_dir, "tmp.fasta")
+        with open(fasta_path, "w") as fp:
+            fp.write(f">{tag}\n{seq}")
@@ -160,27 +181,16 @@ def main(args):
             mode='predict',
         )
-        batch = [processed_feature_dict]
-
-        torch.distributed.broadcast_object_list(batch, src=0)
-        batch = batch[0]
-
-        print("Executing model...")
-        with torch.no_grad():
-            batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
+        batch = processed_feature_dict

-            t = time.perf_counter()
-            out = model(batch)
-            print(f"Inference time: {time.perf_counter() - t}")
+        manager = mp.Manager()
+        result_q = manager.Queue()
+        torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))

-        torch.distributed.barrier()
+        out = result_q.get()

-        if torch.distributed.get_rank() == 0:
         # Toss out the recycling dimensions --- we don't need them anymore
         batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
-        out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

         plddt = out["plddt"]
         mean_plddt = np.mean(plddt)
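Workers started with `mp.spawn` cannot return values directly, so the output rides back through a manager-backed queue; under Dynamic Axial Parallelism every rank ends up holding the same full result, so `main` only needs to `get()` one copy. A self-contained sketch of the handoff (toy payload, no FastFold code):

```python
import torch.multiprocessing as mp

def worker(rank, world_size, result_q):
    # every rank produces the same final output, so one copy is enough
    result_q.put({"rank": rank, "payload": 42})

if __name__ == "__main__":
    manager = mp.Manager()
    result_q = manager.Queue()  # a manager queue proxy survives spawn's pickling
    mp.spawn(worker, nprocs=2, args=(2, result_q))
    out = result_q.get()  # consume a single copy
```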
@@ -213,8 +223,6 @@ def main(args):
         with open(relaxed_output_path, 'w') as f:
             f.write(relaxed_pdb_str)

-        torch.distributed.barrier()
-

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -252,6 +260,10 @@ if __name__ == "__main__":
                         type=int,
                         default=12,
                         help="""Number of CPUs with which to run alignment tools""")
+    parser.add_argument("--gpus",
+                        type=int,
+                        default=1,
+                        help="""Number of GPUs with which to run inference""")
     parser.add_argument('--preset',
                         type=str,
                         default='full_dbs',
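How the new flag feeds the launch, in one hedged sketch (argument names taken from the diff; the list passed to `parse_args` stands in for a real command line):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=1,
                    help="Number of GPUs with which to run inference")
args = parser.parse_args(["--gpus", "2"])  # e.g. `python inference.py ... --gpus 2`
# later, main() uses it as: mp.spawn(inference_model, nprocs=args.gpus, ...)
print(args.gpus)  # 2
```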