Unverified Commit 441bae7f authored by LuGY's avatar LuGY Committed by GitHub
Browse files

Add CI and inference unit test (#85)

* add CI, update Dockerfile

* remove useless loop in inference, add some comments to Attention

* update inference test and CI

* fix path

* add pytest for test env

* add einops install

* add cuda cache, loosen the condition
parent 27f6ab70
# CI workflow: installs FastFold inside a GPU container on a self-hosted
# runner and runs the pytest suite under tests/.
name: Build

on:
  pull_request:
    types: [synchronize, labeled]

jobs:
  build:
    name: Build and Test FastFold
    # All four conditions must hold: the PR is not a draft, it targets
    # 'main', its base repository is hpcaitech/FastFold, and it carries the
    # 'Run Build and Test' label.
    if: |
      github.event.pull_request.draft == false &&
      github.base_ref == 'main' &&
      github.event.pull_request.base.repo.full_name == 'hpcaitech/FastFold' &&
      contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
    runs-on: [self-hosted, gpu]
    container:
      image: hpcaitech/colossalai:0.1.9
      # The bind mount exposes /data/scratch/fastfold from the host at the
      # same path inside the container.
      options: --gpus all --rm -v /data/scratch/fastfold:/data/scratch/fastfold
    timeout-minutes: 40
    steps:
      - uses: actions/checkout@v2
        with:
          repository: hpcaitech/FastFold
          ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}

      # Copies the contents of /github/home/fastfold_cache into the workspace
      # when that directory is non-empty, installs the Python dependencies and
      # FastFold itself, then copies build/ and the *.so files back into the
      # cache directory.
      - name: Install FastFold
        run: |
          [ ! -z "$(ls -A /github/home/fastfold_cache/)" ] && cp -r /github/home/fastfold_cache/* /__w/FastFold/FastFold/
          pip install biopython==1.79 dm-tree==0.1.6 ml-collections==0.1.0 scipy==1.7.1 pandas pytest
          pip install -e .
          cp -r /__w/FastFold/FastFold/build /github/home/fastfold_cache/
          cp /__w/FastFold/FastFold/*.so /github/home/fastfold_cache/

      - name: Unit Testing
        run: |
          PYTHONPATH=$PWD pytest tests
        env:
          NCCL_SHM_DISABLE: 1
......@@ -59,7 +59,7 @@ jobs:
fetch-depth: 0
- name: Copy scripts and checkout
run: |
cp -r ./.github/workflows/* ./
cp -r ./.github/workflows/scripts/build* ./
ln -s /github/home/pip_wheels ./pip_wheels
git checkout $git_ref
env:
......
......@@ -57,7 +57,7 @@ def all_wheel_info():
wheel_info[torch_version][cuda_version][python_version] = dict(url=url)
return wheel_info
def build_colossalai(wheel_info):
def build_fastfold(wheel_info):
cuda_version_major, cuda_version_minor = get_cuda_bare_metal_version()
cuda_version_on_host = f'{cuda_version_major}.{cuda_version_minor}'
......@@ -91,7 +91,7 @@ def main():
if key not in torch_versions:
wheel_info.pop(key)
build_colossalai(wheel_info)
build_fastfold(wheel_info)
if __name__ == '__main__':
main()
FROM hpcaitech/cuda-conda:11.3
FROM hpcaitech/colossalai:0.1.8
RUN conda install openmm=7.7.0 pdbfixer -c conda-forge -y \
&& conda install hmmer==3.3.2 hhsuite=3.3.0 kalign2=2.04 -c bioconda -y
# install dependency
RUN yum install -y patch
RUN conda install pytorch==1.10.0 torchvision torchaudio cudatoolkit=11.3 -c pytorch \
&& conda install setuptools=59.5.0 openmm=7.7.0 pdbfixer -c conda-forge \
&& conda install hmmer==3.3.2 hhsuite=3.3.0 kalign2=2.04 -c bioconda
RUN pip install biopython==1.79 dm-tree==0.1.6 ml-collections==0.1.0 numpy==1.21.2 \
PyYAML==5.4.1 requests==2.26.0 scipy==1.7.1 tqdm==4.62.2 typing-extensions==3.10.0.2 einops ray pyarrow
RUN pip install colossalai==0.1.8+torch1.10cu11.3 -f https://release.colossalai.org
RUN pip install biopython==1.79 dm-tree==0.1.6 ml-collections==0.1.0 \
scipy==1.7.1 ray pyarrow pandas einops
# prepare environment
Run git clone https://github.com/hpcaitech/FastFold.git\
......
......@@ -287,7 +287,7 @@ class SelfAttention(nn.Module):
def forward(self, in_data, mask, nonbatched_bias=None):
"""
:param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
:param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
:param mask: None or [batch_size1, batch_size2, len_kv]
:param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
"""
......@@ -318,8 +318,13 @@ class SelfAttention(nn.Module):
logits = torch.matmul(q, k.transpose(-1, -2))
if nonbatched_bias is not None:
# logits += bias.unsqueeze(1)
# logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
# weights = torch.nn.functional.softmax(logits, -1)
weights = fused_softmax(logits, mask_part, bias.unsqueeze(1))
else:
# logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
# weights = torch.nn.functional.softmax(logits, -1)
weights = fused_softmax(logits, mask_part)
weighted_avg = torch.matmul(weights, v)
......
......@@ -350,105 +350,105 @@ def inference_monomer_model(args):
with open(args.fasta_path, "r") as fp:
fasta = fp.read()
seqs, tags = parse_fasta(fasta)
seq, tag = seqs[0], tags[0]
for tag, seq in zip(tags, seqs):
print(f"tag:{tag}\nseq[{len(seq)}]:{seq}")
batch = [None]
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
print("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
print(f"tag:{tag}\nseq[{len(seq)}]:{seq}")
batch = [None]
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
if (args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
if args.enable_workflow:
print("Running alignment with ray workflow...")
alignment_data_workflow_runner = FastFoldDataWorkFlow(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
t = time.perf_counter()
alignment_data_workflow_runner.run(fasta_path, alignment_dir=local_alignment_dir)
print(f"Alignment data workflow time: {time.perf_counter() - t}")
else:
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(fasta_path, local_alignment_dir)
feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
alignment_dir=local_alignment_dir)
print("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
# Remove temporary FASTA file
os.remove(fasta_path)
if (args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
if args.enable_workflow:
print("Running alignment with ray workflow...")
alignment_data_workflow_runner = FastFoldDataWorkFlow(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
t = time.perf_counter()
alignment_data_workflow_runner.run(fasta_path, alignment_dir=local_alignment_dir)
print(f"Alignment data workflow time: {time.perf_counter() - t}")
else:
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(fasta_path, local_alignment_dir)
feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
alignment_dir=local_alignment_dir)
# Remove temporary FASTA file
os.remove(fasta_path)
processed_feature_dict = feature_processor.process_features(
feature_dict,
mode='predict',
)
processed_feature_dict = feature_processor.process_features(
feature_dict,
mode='predict',
)
batch = processed_feature_dict
batch = processed_feature_dict
manager = mp.Manager()
result_q = manager.Queue()
torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
manager = mp.Manager()
result_q = manager.Queue()
torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
out = result_q.get()
out = result_q.get()
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
unrelaxed_protein = protein.from_prediction(features=batch,
result=out,
b_factors=plddt_b_factors)
unrelaxed_protein = protein.from_prediction(features=batch,
result=out,
b_factors=plddt_b_factors)
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation(
use_gpu=True,
**config.relax,
)
amber_relaxer = relax.AmberRelaxation(
use_gpu=True,
**config.relax,
)
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if __name__ == "__main__":
......
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def random_template_feats(n_templ, n, batch_size=None):
    """Build a dict of randomly filled template feature arrays.

    :param n_templ: size of the template axis of every array.
    :param n: size of the axis that follows the template axis in every
        array except ``template_mask``.
    :param batch_size: when not None, a leading axis of this size is
        prepended to every array.
    :return: dict mapping feature name to a numpy array. Every array is
        float32 except ``template_aatype``, which is int64 with values
        drawn from [0, 22).
    """
    lead = () if batch_size is None else (batch_size,)
    templ_shape = (*lead, n_templ)
    res_shape = (*templ_shape, n)

    def coin_flips(shape):
        # Integers drawn from {0, 1}, stored as float32.
        return np.random.randint(0, 2, shape).astype(np.float32)

    def uniform(*trailing):
        # Samples from np.random.rand with shape res_shape + trailing.
        return np.random.rand(*res_shape, *trailing)

    # The entries are filled in a fixed order so that the sequence of draws
    # from the global numpy RNG stays the same for a given seed.
    feats = {}
    feats["template_mask"] = coin_flips(templ_shape)
    feats["template_pseudo_beta_mask"] = coin_flips(res_shape)
    feats["template_pseudo_beta"] = uniform(3).astype(np.float32)
    feats["template_aatype"] = np.random.randint(0, 22, res_shape).astype(np.int64)
    feats["template_all_atom_mask"] = coin_flips((*res_shape, 37))
    # Scale before the float32 cast so rounding happens once, on the product.
    feats["template_all_atom_positions"] = (uniform(37, 3) * 10).astype(np.float32)
    feats["template_torsion_angles_sin_cos"] = uniform(7, 2).astype(np.float32)
    feats["template_alt_torsion_angles_sin_cos"] = uniform(7, 2).astype(np.float32)
    # NOTE: this "mask" is drawn with np.random.rand, not as binary flags.
    feats["template_torsion_angles_mask"] = uniform(7).astype(np.float32)
    return feats
def random_extra_msa_feats(n_extra, n, batch_size=None):
    """Build a dict of randomly filled extra-MSA feature arrays.

    :param n_extra: size of the first non-batch axis of every array.
    :param n: size of the last axis of every array.
    :param batch_size: when not None, a leading axis of this size is
        prepended to every array.
    :return: dict with ``extra_msa`` (int64, values drawn from [0, 22)) and
        ``extra_has_deletion``, ``extra_deletion_value`` and
        ``extra_msa_mask`` (all float32).
    """
    shape = (n_extra, n) if batch_size is None else (batch_size, n_extra, n)

    def random_flags():
        # Integers drawn from {0, 1}, stored as float32.
        return np.random.randint(0, 2, shape).astype(np.float32)

    # Filled in a fixed order so the sequence of draws from the global numpy
    # RNG stays the same for a given seed.
    feats = {}
    feats["extra_msa"] = np.random.randint(0, 22, shape).astype(np.int64)
    feats["extra_has_deletion"] = random_flags()
    feats["extra_deletion_value"] = np.random.rand(*shape).astype(np.float32)
    feats["extra_msa_mask"] = random_flags()
    return feats
\ No newline at end of file
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
import copy
import pytest
import torch
import ml_collections as mlc
import pickle
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.model.hub import AlphaFold
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.utils import inject_fastnn
from test_data_utils import random_extra_msa_feats, random_template_feats
from fastfold.data import data_transforms
from fastfold.utils.tensor_utils import tensor_tree_map
from fastfold.utils.import_weights import import_jax_weights_
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 2])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace):
    """Spawn ``world_size`` worker processes that each execute ``run_dist``.

    The test is parametrized over ``world_size``, ``chunk_size`` and
    ``inplace``; all three values are forwarded unchanged to every worker.
    """
    # mp.spawn invokes the target as fn(i, *args), with i the process index,
    # so each worker runs run_dist(i, world_size, chunk_size, inplace).
    mp.spawn(run_dist, args=(world_size, chunk_size, inplace), nprocs=world_size)
# Small fixed dimensions; the n_seq / n_templ / n_res / n_extra entries are
# read back by name elsewhere in this file.
consts = mlc.ConfigDict(dict(n_res=11, n_seq=13, n_templ=3, n_extra=17))
def inference():
def run_dist(rank, world_size, chunk_size, inplace):
    """Per-process entry point for the spawned inference test.

    Publishes ``rank`` and ``world_size`` through the ``RANK``,
    ``LOCAL_RANK`` and ``WORLD_SIZE`` environment variables, calls
    ``fastfold.distributed.init_dap()`` and then runs
    ``inference(chunk_size, inplace)``.

    :param rank: index of this process; exported as both RANK and LOCAL_RANK.
    :param world_size: exported as WORLD_SIZE.
    :param chunk_size: forwarded to ``inference``.
    :param inplace: forwarded to ``inference``.
    """
    rank_str = str(rank)
    os.environ.update({
        'RANK': rank_str,
        'LOCAL_RANK': rank_str,
        'WORLD_SIZE': str(world_size),
    })

    # init distributed for Dynamic Axial Parallelism
    fastfold.distributed.init_dap()

    inference(chunk_size, inplace)
n_seq = consts.n_seq
n_templ = consts.n_templ
n_res = consts.n_res
n_extra_seq = consts.n_extra
def inference(chunk_size, inplace):
config = model_config('model_1')
config.globals.chunk_size = chunk_size
config.globals.inplace = False
model = AlphaFold(config)
model = inject_fastnn(model)
import_jax_weights_(model, '/data/scratch/fastfold/weight.npz')
model.eval()
model.cuda()
set_chunk_size(model.globals.chunk_size)
batch = {}
tf = torch.randint(config.model.input_embedder.tf_dim - 1, size=(n_res,))
batch["target_feat"] = torch.nn.functional.one_hot(
tf, config.model.input_embedder.tf_dim).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((n_seq, n_res, config.model.input_embedder.msa_dim))
t_feats = random_template_feats(n_templ, n_res)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
batch["msa_mask"] = torch.randint(low=0, high=2, size=(n_seq, n_res)).float()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
batch.update(data_transforms.make_atom14_masks(batch))
batch["no_recycling_iters"] = torch.tensor(2.)
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, config.data.common.max_recycling_iters))
batch = tensor_tree_map(add_recycling_dims, batch)
fastmodel = copy.deepcopy(model)
fastmodel = inject_fastnn(fastmodel)
fastmodel.eval()
fastmodel.cuda()
set_chunk_size(model.globals.chunk_size)
batch = pickle.load(open('/data/scratch/fastfold/mono_batch.pkl', 'rb'))
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
fastbatch = copy.deepcopy(batch)
with torch.no_grad():
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
t = time.perf_counter()
out = model(batch)
print(f"Inference time: {time.perf_counter() - t}")
config.globals.inplace = inplace
fastout = fastmodel(fastbatch)
if __name__ == "__main__":
inference()
print("Inference Test Passed!")
\ No newline at end of file
pos_dif = torch.max(torch.abs(fastout["final_atom_positions"] - out["final_atom_positions"]))
assert pos_dif < 1.5, f"Test failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {pos_dif}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment