Unverified Commit 441bae7f authored by LuGY, committed by GitHub

Add CI and inference unit test (#85)

* add CI, update Dockerfile

* remove useless loop in inference, add some comments to Attention

* update inference test and CI

* fix path

* add pytest for test env

* add einops install

* add CUDA cache, loosen the condition
parent 27f6ab70
name: Build

on:
  pull_request:
    types: [synchronize, labeled]

jobs:
  build:
    name: Build and Test FastFold
    if: |
      github.event.pull_request.draft == false &&
      github.base_ref == 'main' &&
      github.event.pull_request.base.repo.full_name == 'hpcaitech/FastFold' &&
      contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
    runs-on: [self-hosted, gpu]
    container:
      image: hpcaitech/colossalai:0.1.9
      options: --gpus all --rm -v /data/scratch/fastfold:/data/scratch/fastfold
    timeout-minutes: 40
    steps:
      - uses: actions/checkout@v2
        with:
          repository: hpcaitech/FastFold
          ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
      - name: Install FastFold
        run: |
          [ ! -z "$(ls -A /github/home/fastfold_cache/)" ] && cp -r /github/home/fastfold_cache/* /__w/FastFold/FastFold/
          pip install biopython==1.79 dm-tree==0.1.6 ml-collections==0.1.0 scipy==1.7.1 pandas pytest
          pip install -e .
          cp -r /__w/FastFold/FastFold/build /github/home/fastfold_cache/
          cp /__w/FastFold/FastFold/*.so /github/home/fastfold_cache/
      - name: Unit Testing
        run: |
          PYTHONPATH=$PWD pytest tests
        env:
          NCCL_SHM_DISABLE: 1
@@ -59,7 +59,7 @@ jobs:
           fetch-depth: 0
       - name: Copy scripts and checkout
         run: |
-          cp -r ./.github/workflows/* ./
+          cp -r ./.github/workflows/scripts/build* ./
           ln -s /github/home/pip_wheels ./pip_wheels
           git checkout $git_ref
         env:
...
@@ -57,7 +57,7 @@ def all_wheel_info():
         wheel_info[torch_version][cuda_version][python_version] = dict(url=url)
     return wheel_info

-def build_colossalai(wheel_info):
+def build_fastfold(wheel_info):
     cuda_version_major, cuda_version_minor = get_cuda_bare_metal_version()
     cuda_version_on_host = f'{cuda_version_major}.{cuda_version_minor}'
@@ -91,7 +91,7 @@ def main():
         if key not in torch_versions:
             wheel_info.pop(key)
-    build_colossalai(wheel_info)
+    build_fastfold(wheel_info)

 if __name__ == '__main__':
     main()
-FROM hpcaitech/cuda-conda:11.3
+FROM hpcaitech/colossalai:0.1.8

-# install dependency
-RUN yum install -y patch
-
-RUN conda install pytorch==1.10.0 torchvision torchaudio cudatoolkit=11.3 -c pytorch \
-    && conda install setuptools=59.5.0 openmm=7.7.0 pdbfixer -c conda-forge \
-    && conda install hmmer==3.3.2 hhsuite=3.3.0 kalign2=2.04 -c bioconda
-
-RUN pip install biopython==1.79 dm-tree==0.1.6 ml-collections==0.1.0 numpy==1.21.2 \
-    PyYAML==5.4.1 requests==2.26.0 scipy==1.7.1 tqdm==4.62.2 typing-extensions==3.10.0.2 einops ray pyarrow
-
-RUN pip install colossalai==0.1.8+torch1.10cu11.3 -f https://release.colossalai.org
+RUN conda install openmm=7.7.0 pdbfixer -c conda-forge -y \
+    && conda install hmmer==3.3.2 hhsuite=3.3.0 kalign2=2.04 -c bioconda -y
+
+RUN pip install biopython==1.79 dm-tree==0.1.6 ml-collections==0.1.0 \
+    scipy==1.7.1 ray pyarrow pandas einops

 # prepare environment
 Run git clone https://github.com/hpcaitech/FastFold.git\
...
@@ -287,7 +287,7 @@ class SelfAttention(nn.Module):
     def forward(self, in_data, mask, nonbatched_bias=None):
         """
         :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
-        :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
+        :param mask: None or [batch_size1, batch_size2, len_kv]
         :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
         """
@@ -318,8 +318,13 @@ class SelfAttention(nn.Module):
         logits = torch.matmul(q, k.transpose(-1, -2))

         if nonbatched_bias is not None:
+            # logits += bias.unsqueeze(1)
+            # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
+            # weights = torch.nn.functional.softmax(logits, -1)
             weights = fused_softmax(logits, mask_part, bias.unsqueeze(1))
         else:
+            # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
+            # weights = torch.nn.functional.softmax(logits, -1)
             weights = fused_softmax(logits, mask_part)

         weighted_avg = torch.matmul(weights, v)
...
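The comments added above keep the unfused reference computation next to the fused_softmax call. For illustration only, a minimal standalone sketch of that reference path (the function name is illustrative and not part of FastFold; shapes follow the docstring):

import torch

def unfused_masked_softmax(logits, mask, bias=None):
    # logits: [b1, b2, n_head, len_q, len_kv]
    # mask:   [b1, b2, len_kv]
    # bias:   nonbatched_bias.unsqueeze(1), i.e. [b1, 1, n_head, len_q, len_kv], or None
    if bias is not None:
        logits = logits + bias
    # push masked key positions to a large negative value before the softmax
    logits = logits + (1e9 * (mask - 1))[..., :, None, None, :]
    return torch.nn.functional.softmax(logits, dim=-1)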
@@ -350,8 +350,8 @@ def inference_monomer_model(args):
     with open(args.fasta_path, "r") as fp:
         fasta = fp.read()
     seqs, tags = parse_fasta(fasta)
-    for tag, seq in zip(tags, seqs):
-        print(f"tag:{tag}\nseq[{len(seq)}]:{seq}")
+    seq, tag = seqs[0], tags[0]
+    print(f"tag:{tag}\nseq[{len(seq)}]:{seq}")

     batch = [None]
...
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np


def random_template_feats(n_templ, n, batch_size=None):
    """Generate a dict of random template features for n_templ templates of length n."""
    b = []
    if batch_size is not None:
        b.append(batch_size)
    batch = {
        "template_mask": np.random.randint(0, 2, (*b, n_templ)),
        "template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)),
        "template_pseudo_beta": np.random.rand(*b, n_templ, n, 3),
        "template_aatype": np.random.randint(0, 22, (*b, n_templ, n)),
        "template_all_atom_mask": np.random.randint(
            0, 2, (*b, n_templ, n, 37)
        ),
        "template_all_atom_positions":
            np.random.rand(*b, n_templ, n, 37, 3) * 10,
        "template_torsion_angles_sin_cos":
            np.random.rand(*b, n_templ, n, 7, 2),
        "template_alt_torsion_angles_sin_cos":
            np.random.rand(*b, n_templ, n, 7, 2),
        "template_torsion_angles_mask":
            np.random.rand(*b, n_templ, n, 7),
    }
    batch = {k: v.astype(np.float32) for k, v in batch.items()}
    batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
    return batch


def random_extra_msa_feats(n_extra, n, batch_size=None):
    """Generate a dict of random extra-MSA features for n_extra sequences of length n."""
    b = []
    if batch_size is not None:
        b.append(batch_size)
    batch = {
        "extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(
            np.int64
        ),
        "extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(
            np.float32
        ),
        "extra_deletion_value": np.random.rand(*b, n_extra, n).astype(
            np.float32
        ),
        "extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(
            np.float32
        ),
    }
    return batch
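For context, a brief sketch (sizes arbitrary) of how these helpers are consumed, mirroring the batch construction used by the previous version of the inference test:

import torch
from test_data_utils import random_extra_msa_feats, random_template_feats

batch = {}
# template features as float32/int64 numpy arrays, converted to tensors
t_feats = random_template_feats(n_templ=3, n=11)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
# extra-MSA features for the same sequence length
extra_feats = random_extra_msa_feats(n_extra=17, n=11)
batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})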
-# Copyright 2021 AlQuraishi Laboratory
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import time
+import os
+import copy
+import pytest
 import torch
-import ml_collections as mlc
+import pickle
+import torch.multiprocessing as mp
+from functools import partial

 import fastfold
 from fastfold.model.hub import AlphaFold
 from fastfold.config import model_config
 from fastfold.model.fastnn import set_chunk_size
 from fastfold.utils import inject_fastnn
-from test_data_utils import random_extra_msa_feats, random_template_feats
-from fastfold.data import data_transforms
-from fastfold.utils.tensor_utils import tensor_tree_map
+from fastfold.utils.import_weights import import_jax_weights_


-consts = mlc.ConfigDict(
-    {
-        "n_res": 11,
-        "n_seq": 13,
-        "n_templ": 3,
-        "n_extra": 17,
-    }
-)
+@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('chunk_size', [None, 2])
+@pytest.mark.parametrize('inplace', [False, True])
+def test_state_dict(world_size, chunk_size, inplace):
+    run_func = partial(run_dist, world_size=world_size, chunk_size=chunk_size, inplace=inplace)
+    mp.spawn(run_func, nprocs=world_size)


-def inference():
+def run_dist(rank, world_size, chunk_size, inplace):
+    os.environ['RANK'] = str(rank)
+    os.environ['LOCAL_RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    # init distributed for Dynamic Axial Parallelism
     fastfold.distributed.init_dap()
+    inference(chunk_size, inplace)

-    n_seq = consts.n_seq
-    n_templ = consts.n_templ
-    n_res = consts.n_res
-    n_extra_seq = consts.n_extra

+def inference(chunk_size, inplace):
     config = model_config('model_1')
+    config.globals.chunk_size = chunk_size
+    config.globals.inplace = False
     model = AlphaFold(config)
-    model = inject_fastnn(model)
+    import_jax_weights_(model, '/data/scratch/fastfold/weight.npz')
     model.eval()
     model.cuda()
-    set_chunk_size(model.globals.chunk_size)
+    fastmodel = copy.deepcopy(model)
+    fastmodel = inject_fastnn(fastmodel)
+    fastmodel.eval()
+    fastmodel.cuda()

-    batch = {}
-    tf = torch.randint(config.model.input_embedder.tf_dim - 1, size=(n_res,))
-    batch["target_feat"] = torch.nn.functional.one_hot(
-        tf, config.model.input_embedder.tf_dim).float()
-    batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
-    batch["residue_index"] = torch.arange(n_res)
-    batch["msa_feat"] = torch.rand((n_seq, n_res, config.model.input_embedder.msa_dim))
-    t_feats = random_template_feats(n_templ, n_res)
-    batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
-    extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
-    batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
-    batch["msa_mask"] = torch.randint(low=0, high=2, size=(n_seq, n_res)).float()
-    batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
-    batch.update(data_transforms.make_atom14_masks(batch))
-    batch["no_recycling_iters"] = torch.tensor(2.)
-    add_recycling_dims = lambda t: (
-        t.unsqueeze(-1).expand(*t.shape, config.data.common.max_recycling_iters))
-    batch = tensor_tree_map(add_recycling_dims, batch)
+    set_chunk_size(model.globals.chunk_size)
+    batch = pickle.load(open('/data/scratch/fastfold/mono_batch.pkl', 'rb'))
+    batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
+    fastbatch = copy.deepcopy(batch)

     with torch.no_grad():
-        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
-        t = time.perf_counter()
         out = model(batch)
-        print(f"Inference time: {time.perf_counter() - t}")
-
-if __name__ == "__main__":
-    inference()
+        config.globals.inplace = inplace
+        fastout = fastmodel(fastbatch)
+
+        pos_dif = torch.max(torch.abs(fastout["final_atom_positions"] - out["final_atom_positions"]))
+        assert pos_dif < 1.5, f"Test failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {pos_dif}"
+        print("Inference Test Passed!")
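For a quick local run outside of pytest, one could add an entry point along these lines. This is a hypothetical sketch, not part of the commit; it assumes weight.npz and mono_batch.pkl are available under /data/scratch/fastfold, as they are in the CI container mount.

if __name__ == "__main__":
    # manually run a single case; pytest normally drives the full parameter grid
    test_state_dict(world_size=1, chunk_size=None, inplace=False)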