Commit ed759cf4 authored by Shenggan

fix mask softmax and support inference for arbitrary length sequences

parent 16d10d6a
@@ -57,6 +57,35 @@ torch.distributed.init_process_group(backend='nccl', init_method='env://')
init_dap(args.dap_size)
```
### Inference
You can use FastFold together with OpenFold via `inject_openfold`, which replaces the Evoformer in OpenFold with FastFold's high-performance Evoformer.
```python
from fastfold.utils import inject_openfold
model = AlphaFold(config)
import_jax_weights_(model, args.param_path, version=args.model_name)
model = inject_openfold(model)
```
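Once injected, the model behaves like a regular OpenFold `AlphaFold` module. Below is a minimal single-GPU sketch, assuming `processed_feature_dict` has already been produced by OpenFold's data pipeline (see `./inference.py` for the full preprocessing):
```python
import torch

model = model.eval().cuda()

with torch.no_grad():
    batch = {k: torch.as_tensor(v).cuda() for k, v in processed_feature_dict.items()}
    out = model(batch)
```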
For Dynamic Axial Parallelism, refer to `./inference.py`. Here is an example of parallel inference on 2 GPUs:
```shell
torchrun --nproc_per_node=2 inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--output_dir ./ \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign
```
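`./inference.py` enables Dynamic Axial Parallelism only when it detects a `torchrun` launch via the `LOCAL_RANK` environment variable, so the same script also runs on a single GPU when started with plain `python`. The relevant initialization from this commit is:
```python
import os
import torch
import fastfold

local_rank = int(os.getenv('LOCAL_RANK', -1))
if local_rank != -1:
    # launched by torchrun: one process per GPU, DAP size = world size
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    fastfold.distributed.init_dap(torch.distributed.get_world_size())
```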
## Performance Benchmark
We have included a performance benchmark script in `./benchmark`. You can benchmark the performance of Evoformer using different settings.
...
@@ -314,7 +314,7 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
    #pragma unroll
    for (int i = 0; i < cols_this_thread; i++) {
        if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
-            buf[i] = -1 * CUDART_INF_F;
+            buf[i] = -1 * 1e9;
        } else {
            buf[i] = row_input[lane_id * cols_per_thread + i] * scale;
        }
@@ -373,7 +373,7 @@ __global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloa
    #pragma unroll
    for (int i = 0; i < cols_this_thread; i++) {
        if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
-            buf[i] = -1 * CUDART_INF_F;
+            buf[i] = -1 * 10e9;
        } else {
            buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]) * scale;
        }
@@ -601,7 +601,7 @@ __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask,
    #pragma unroll
    for (int i = 0; i < cols_this_thread; i++) {
        if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
-            buf[i] = -1 * CUDART_INF_F;
+            buf[i] = -1 * 10e9;
        } else {
            buf[i] = row_input[lane_id * cols_per_thread + i] * scale +
                     bias_ptr[lane_id * cols_per_thread + i];
@@ -662,7 +662,7 @@ __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::
    #pragma unroll
    for (int i = 0; i < cols_this_thread; i++) {
        if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
-            buf[i] = -1 * CUDART_INF_F;
+            buf[i] = -1 * 10e9;
        } else {
            buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]) * scale;
            buf[i] += static_cast<float>(bias_ptr[lane_id * cols_per_thread + i]);
@@ -743,4 +743,4 @@ at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tenso
        }
    }
    return grad_input;
}
\ No newline at end of file
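The kernel change above replaces `-CUDART_INF_F` with a large negative finite constant at masked positions. This matters for rows that are entirely masked (for example, padded residues added to reach a parallel-friendly sequence length): with every logit at `-inf`, softmax produces NaN, while a large negative finite value keeps the row well defined. A small PyTorch illustration of the effect:
```python
import torch

# Fully masked row: every logit is -inf, softmax yields NaN.
row = torch.full((4,), float("-inf"))
print(torch.softmax(row, dim=-1))  # tensor([nan, nan, nan, nan])

# Large negative finite value instead: the row stays finite
# (a uniform distribution over the masked positions).
row = torch.full((4,), -1e9)
print(torch.softmax(row, dim=-1))  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```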
@@ -6,6 +6,7 @@ import torch.nn as nn
from fastfold.model import MSAStack, OutProductMean, PairStack
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.distributed.comm import gather, scatter
+from fastfold.distributed import get_tensor_model_parallel_world_size


class EvoformerBlock(nn.Module):
@@ -30,17 +31,29 @@ class EvoformerBlock(nn.Module):
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

+        dap_size = get_tensor_model_parallel_world_size()
+
+        seq_length = pair_mask.size(-1)
+        padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
+
        if self.first_block:
            m = m.unsqueeze(0)
            z = z.unsqueeze(0)

+            m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
+            z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
+
            m = scatter(m, dim=1)
            z = scatter(z, dim=1)

        msa_mask = msa_mask.unsqueeze(0)
        pair_mask = pair_mask.unsqueeze(0)

+        msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
+        pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
+
        m = self.msa_stack(m, z, msa_mask)
        z = z + self.communication(m, msa_mask)
        m, work = All_to_All_Async.apply(m, 1, 2)
        z = self.pair_stack(z, pair_mask)
@@ -53,6 +66,9 @@ class EvoformerBlock(nn.Module):
            m = gather(m, dim=0)
            z = gather(z, dim=0)

+            m = m[:, :-padding_size, :]
+            z = z[:-padding_size, :-padding_size, :]
+
        return m, z
...
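The padding added to `EvoformerBlock` rounds the residue dimension up to a multiple of the DAP world size so that `scatter` can split the tensors evenly across ranks, and the padded residues are sliced off again after the final `gather`. A standalone sketch of the size computation and the `pad` calls (the tensor shapes here are illustrative, not FastFold's exact layout):
```python
import torch
import torch.nn.functional as F

dap_size = 4       # DAP world size (number of model-parallel ranks)
seq_length = 237   # illustrative sequence length

# Same formula as in EvoformerBlock: pads 1..dap_size residues so the
# padded length (240 here) is divisible by dap_size.
padding_size = (seq_length // dap_size + 1) * dap_size - seq_length

z = torch.randn(1, seq_length, seq_length, 64)            # illustrative pair tensor
z = F.pad(z, (0, 0, 0, padding_size, 0, padding_size))    # pad both residue dims
assert z.size(1) % dap_size == 0 and z.size(2) % dap_size == 0

# After the blocks have run and the shards are gathered back,
# the padded residues are dropped:
z = z[:, :-padding_size, :-padding_size, :]
print(z.shape)  # torch.Size([1, 237, 237, 64])
```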
@@ -14,16 +14,15 @@
# limitations under the License.

import argparse
-import logging
import os
import random
import sys
import time
-from datetime import date

import numpy as np
import torch

+import fastfold
import openfold.np.relax.relax as relax
from fastfold.utils import inject_openfold
from openfold.config import model_config
@@ -37,6 +36,17 @@ from scripts.utils import add_data_args

def main(args):
+    # init distributed for Dynamic Axial Parallelism
+    local_rank = int(os.getenv('LOCAL_RANK', -1))
+    if local_rank != -1:
+        distributed_inference_ = True
+        torch.cuda.set_device(local_rank)
+        torch.distributed.init_process_group(backend='nccl', init_method='env://')
+        fastfold.distributed.init_dap(torch.distributed.get_world_size())
+    else:
+        distributed_inference_ = False
+
    config = model_config(args.model_name)
    model = AlphaFold(config)
    import_jax_weights_(model, args.param_path, version=args.model_name)
@@ -44,7 +54,7 @@ def main(args):
    model = inject_openfold(model)
    model = model.eval()
    #script_preset_(model)
-    model = model.to(args.model_device)
+    model = model.cuda()

    template_featurizer = templates.TemplateHitFeaturizer(
        mmcif_dir=args.template_mmcif_dir,
@@ -78,89 +88,99 @@
    tags = [l[1:] for l in tags]
    for tag, seq in zip(tags, seqs):
+        batch = [None]
+        if (not distributed_inference_) or (torch.distributed.get_rank() == 0):
            fasta_path = os.path.join(args.output_dir, "tmp.fasta")
            with open(fasta_path, "w") as fp:
                fp.write(f">{tag}\n{seq}")

            print("Generating features...")
            local_alignment_dir = os.path.join(alignment_dir, tag)
            if (args.use_precomputed_alignments is None):
                if not os.path.exists(local_alignment_dir):
                    os.makedirs(local_alignment_dir)

                alignment_runner = data_pipeline.AlignmentRunner(
                    jackhmmer_binary_path=args.jackhmmer_binary_path,
                    hhblits_binary_path=args.hhblits_binary_path,
                    hhsearch_binary_path=args.hhsearch_binary_path,
                    uniref90_database_path=args.uniref90_database_path,
                    mgnify_database_path=args.mgnify_database_path,
                    bfd_database_path=args.bfd_database_path,
                    uniclust30_database_path=args.uniclust30_database_path,
                    pdb70_database_path=args.pdb70_database_path,
                    use_small_bfd=use_small_bfd,
                    no_cpus=args.cpus,
                )
                alignment_runner.run(fasta_path, local_alignment_dir)

            feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
                                                        alignment_dir=local_alignment_dir)

            # Remove temporary FASTA file
            os.remove(fasta_path)

            processed_feature_dict = feature_processor.process_features(
                feature_dict,
                mode='predict',
            )

+            batch = [processed_feature_dict]
+
+        if distributed_inference_:
+            torch.distributed.broadcast_object_list(batch, src=0)
+        batch = batch[0]

        print("Executing model...")
-        batch = processed_feature_dict
        with torch.no_grad():
-            batch = {k: torch.as_tensor(v, device=args.model_device) for k, v in batch.items()}
+            batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}

            t = time.perf_counter()
            out = model(batch)
            print(f"Inference time: {time.perf_counter() - t}")

+        if distributed_inference_:
+            torch.distributed.barrier()
+
+        if (not distributed_inference_) or (torch.distributed.get_rank() == 0):
            # Toss out the recycling dimensions --- we don't need them anymore
            batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

            plddt = out["plddt"]
            mean_plddt = np.mean(plddt)

            plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)

            unrelaxed_protein = protein.from_prediction(features=batch,
                                                        result=out,
                                                        b_factors=plddt_b_factors)

            # Save the unrelaxed PDB.
            unrelaxed_output_path = os.path.join(args.output_dir,
                                                 f'{tag}_{args.model_name}_unrelaxed.pdb')
            with open(unrelaxed_output_path, 'w') as f:
                f.write(protein.to_pdb(unrelaxed_protein))

            amber_relaxer = relax.AmberRelaxation(
-                use_gpu=(args.model_device != "cpu"),
+                use_gpu=True,
                **config.relax,
            )

            # Relax the prediction.
            t = time.perf_counter()
-            visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
-            if ("cuda" in args.model_device):
-                device_no = args.model_device.split(":")[-1]
-                os.environ["CUDA_VISIBLE_DEVICES"] = device_no
            relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-            if visible_devices:
-                os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
            print(f"Relaxation time: {time.perf_counter() - t}")

            # Save the relaxed PDB.
-            relaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb')
+            relaxed_output_path = os.path.join(args.output_dir,
+                                               f'{tag}_{args.model_name}_relaxed.pdb')
            with open(relaxed_output_path, 'w') as f:
                f.write(relaxed_pdb_str)
+
+        if distributed_inference_:
+            torch.distributed.barrier()
if __name__ == "__main__":
@@ -184,11 +204,6 @@ if __name__ == "__main__":
        default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
-    parser.add_argument("--model_device",
-                        type=str,
-                        default="cpu",
-                        help="""Name of the device on which to run the model. Any valid torch
-                        device name is accepted (e.g. "cpu", "cuda:0")""")
    parser.add_argument("--model_name",
                        type=str,
                        default="model_1",
@@ -216,8 +231,4 @@ if __name__ == "__main__":
    args.param_path = os.path.join("openfold", "resources", "params",
                                   "params_" + args.model_name + ".npz")

-    if (args.model_device == "cpu" and torch.cuda.is_available()):
-        logging.warning("""The model is being run on CPU. Consider specifying
-            --model_device for better performance""")
-
    main(args)
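In the distributed path above, only rank 0 runs the CPU-heavy alignment and feature pipeline; the resulting feature dictionary is shipped to the other DAP ranks with `torch.distributed.broadcast_object_list`, and every rank then runs the model on its own copy. A condensed sketch of that pattern (`build_features` is a hypothetical stand-in for the data-pipeline calls above):
```python
import torch.distributed as dist

def features_on_all_ranks(build_features, distributed_inference_):
    # broadcast_object_list fills the list in place on receiving ranks,
    # so the features are wrapped in a one-element list.
    batch = [None]
    if (not distributed_inference_) or dist.get_rank() == 0:
        batch = [build_features()]
    if distributed_inference_:
        dist.broadcast_object_list(batch, src=0)
    return batch[0]
```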