Unverified Commit 4dc719f7 authored by Leo Zhao's avatar Leo Zhao Committed by GitHub
Browse files

optimize code for bwd performance and refine code. (#145)

* optimize code for bwd performance and refine code.

* refine README to add test SW version
parent 95150c38
......@@ -210,11 +210,12 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
### Inference or Training on Intel Habana
To run AlphaFold inference or training on Intel Habana, you can follow the instructions in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/) to set up your environment on Amazon EC2 DL1 instances or on-premise environments.
To run AlphaFold inference or training on Intel Habana, you can follow the instructions in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/) to set up your environment on Amazon EC2 DL1 instances or on-premise environments. Please use SynapseAI R1.7.1 for testing, as that is the version verified internally.
Once you have prepared your dataset and installed fastfold, you can use the following scripts:
```shell
cd fastfold/habana/fastnn/custom_op/; python setup.py build   # for Gaudi; on Gaudi2, use setup2.py instead
cd -
bash habana/inference.sh
bash habana/train.sh
```
......
from .comm import (All_to_All, _gather, _reduce, _split, col_to_row, copy,
gather, reduce, row_to_col, scatter)
from .core import init_dist
from .core import init_dist, get_data_parallel_world_size
__all__ = [
'init_dist', '_reduce', '_split', '_gather', 'copy', 'scatter', 'reduce', 'gather',
'init_dist', 'get_data_parallel_world_size', '_reduce', '_split', '_gather', 'copy', 'scatter', 'reduce', 'gather',
'col_to_row', 'row_to_col', 'All_to_All'
]
......@@ -18,7 +18,7 @@ from typing import Tuple, Any, Sequence, Callable, Optional
import numpy as np
import torch
import fastfold.habana as habana
def rot_matmul(
a: torch.Tensor,
......@@ -34,6 +34,19 @@ def rot_matmul(
Returns:
The product ab
"""
if habana.is_habana():
if len(a.shape) == 4 and a.shape[1] == 1:
aa = a.permute(0, 1, 3, 2)
bb = b.permute(0, 1, 3, 2)
cc = bb @ aa
cc = cc.permute(0, 1, 3, 2)
return cc
elif len(a.shape) == 4 and a.shape[1] != 1:
pass
else:
cc = a @ b
return cc
row_1 = torch.stack(
[
a[..., 0, 0] * b[..., 0, 0]
......@@ -94,6 +107,20 @@ def rot_vec_mul(
Returns:
[*, 3] rotated coordinates
"""
if habana.is_habana():
cont = True
if len(t.shape) == 4 and t.shape[1] == 1:
cont = False
elif len(t.shape) == 3 and t.shape[0] != r.shape[0] and t.shape[0] == 1:
cont = False
if cont:
tt = t.unsqueeze(-2)
rr = r.transpose(-2, -1)
cc = tt @ rr
cc = cc.squeeze(-2)
return cc
x = t[..., 0]
y = t[..., 1]
z = t[..., 2]
......
export GC_KERNEL_PATH=./fastfold/habana/fastnn/custom_op/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
export PYTHONPATH=./:$PYTHONPATH
# add '--gpus [N]' to use N gpus for inference
# add '--enable_workflow' to use parallel workflow for data processing
# add '--use_precomputed_alignments [path_to_alignments]' to use precomputed msa
......
......@@ -10,7 +10,7 @@ from tqdm import tqdm
import fastfold.habana as habana
from fastfold.config import model_config
from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader
from fastfold.habana.distributed import init_dist
from fastfold.habana.distributed import init_dist, get_data_parallel_world_size
from fastfold.habana.inject_habana import inject_habana
from fastfold.model.hub import AlphaFold, AlphaFoldLoss, AlphaFoldLRScheduler
from fastfold.utils.tensor_utils import tensor_tree_map
......@@ -156,7 +156,8 @@ def main():
model = inject_habana(model)
model = model.to(device="hpu")
model = DDP(model)
if get_data_parallel_world_size() > 1:
model = DDP(model, gradient_as_bucket_view=True, bucket_cap_mb=400)
train_dataset, test_dataset = SetupTrainDataset(
config=config.data,
......@@ -201,27 +202,32 @@ def main():
isVerbose=args.hmp_verbose)
print("========= HMP ENABLED!!")
idx = 0
for epoch in range(200):
model.train()
train_dataloader = tqdm(train_dataloader)
for batch in train_dataloader:
perf = hpu_perf("train step")
batch = {k: torch.as_tensor(v).to(device="hpu") for k, v in batch.items()}
batch = {k: torch.as_tensor(v).to(device="hpu", non_blocking=True) for k, v in batch.items()}
optimizer.zero_grad()
perf.checknow("prepare input and zero grad")
output = model(batch)
perf.checknow("forward")
batch = tensor_tree_map(lambda t: t[..., -1], batch)
perf.checknow("prepare loss input")
loss, loss_breakdown = criterion(output, batch, _return_breakdown=True)
perf.checknow("loss")
loss.backward()
train_dataloader.set_postfix(loss=float(loss))
if idx % 10 == 0:
train_dataloader.set_postfix(loss=float(loss))
perf.checknow("backward")
with hmp.disable_casts():
optimizer.step()
perf.checknow("optimizer")
idx += 1
lr_scheduler.step()
......
DATA_DIR=/mnt/usb/training-demo
export GC_KERNEL_PATH=./fastfold/habana/fastnn/custom_op/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
export PYTHONPATH=./:$PYTHONPATH
DATA_DIR=../FastFold-dataset/train
hpus_per_node=1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment