Unverified commit 4dc719f7 authored by Leo Zhao, committed by GitHub

optimize code for bwd performance and refine code. (#145)

* optimize code for bwd performance and refine code.

* refine README to add test SW version
parent 95150c38
@@ -210,11 +210,12 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
 ### Inference or Training on Intel Habana
-To run AlphaFold inference or training on Intel Habana, you can follow the instructions in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/) to set up your environment on Amazon EC2 DL1 instances or on-premise environments.
+To run AlphaFold inference or training on Intel Habana, follow the instructions in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/) to set up your environment on Amazon EC2 DL1 instances or on-premise environments. Please use SynapseAI R1.7.1 for testing, as this is the software version that was verified internally.
 Once you have prepared your dataset and installed fastfold, you can use the following scripts:
 ```shell
+# Build the Habana custom op first (Gaudi: setup.py, Gaudi2: setup2.py).
+cd fastfold/habana/fastnn/custom_op/; python setup.py build; cd -
 bash habana/inference.sh
 bash habana/train.sh
 ```
......
 from .comm import (All_to_All, _gather, _reduce, _split, col_to_row, copy,
                    gather, reduce, row_to_col, scatter)
-from .core import init_dist
+from .core import init_dist, get_data_parallel_world_size
 __all__ = [
-    'init_dist', '_reduce', '_split', '_gather', 'copy', 'scatter', 'reduce', 'gather',
+    'init_dist', 'get_data_parallel_world_size', '_reduce', '_split', '_gather', 'copy', 'scatter', 'reduce', 'gather',
     'col_to_row', 'row_to_col', 'All_to_All'
 ]
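The newly exported helper, `get_data_parallel_world_size`, is defined in `fastfold/habana/distributed/core.py`, which is not part of this diff. As a rough sketch of what such a helper usually does (the body below is an assumption, not the repository's actual implementation), it typically just queries `torch.distributed`:

```python
# Hypothetical sketch only; the real helper lives in fastfold/habana/distributed/core.py
# and may consult a dedicated data-parallel process group instead of the default one.
import torch.distributed as dist

def get_data_parallel_world_size() -> int:
    """Number of data-parallel ranks, or 1 when no process group is initialized."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
```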
@@ -18,7 +18,7 @@ from typing import Tuple, Any, Sequence, Callable, Optional
 import numpy as np
 import torch
+import fastfold.habana as habana

 def rot_matmul(
     a: torch.Tensor,
@@ -34,6 +34,19 @@ def rot_matmul(
         Returns:
             The product ab
     """
+    if habana.is_habana():
+        if len(a.shape) == 4 and a.shape[1] == 1:
+            aa = a.permute(0, 1, 3, 2)
+            bb = b.permute(0, 1, 3, 2)
+            cc = bb @ aa
+            cc = cc.permute(0, 1, 3, 2)
+            return cc
+        elif len(a.shape) == 4 and a.shape[1] != 1:
+            pass
+        else:
+            cc = a @ b
+            return cc
+
     row_1 = torch.stack(
         [
             a[..., 0, 0] * b[..., 0, 0]
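The Habana branch above replaces the hand-unrolled 3x3 product with batched matmuls. For the `[*, 1, 3, 3]` case it computes `(bᵀ aᵀ)ᵀ`, which equals `a b` by the transpose identity, so the result is unchanged; only the order of operations is rearranged. A small self-contained check of that identity (the tensor shapes below are chosen only for illustration):

```python
# Sanity check (illustrative shapes): (b^T @ a^T)^T == a @ b,
# the identity the Habana fast path in rot_matmul relies on.
# Note: permute(0, 1, 3, 2) in the diff is transpose(-1, -2) for 4-D tensors.
import torch

a = torch.randn(2, 1, 3, 3)   # mirrors the a.shape[1] == 1 branch
b = torch.randn(2, 5, 3, 3)

fast = (b.transpose(-1, -2) @ a.transpose(-1, -2)).transpose(-1, -2)
ref = a @ b

assert torch.allclose(fast, ref, atol=1e-6)
```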
@@ -94,6 +107,20 @@ def rot_vec_mul(
         Returns:
             [*, 3] rotated coordinates
     """
+    if habana.is_habana():
+        cont = True
+        if len(t.shape) == 4 and t.shape[1] == 1:
+            cont = False
+        elif len(t.shape) == 3 and t.shape[0] != r.shape[0] and t.shape[0] == 1:
+            cont = False
+        if cont:
+            tt = t.unsqueeze(-2)
+            rr = r.transpose(-2, -1)
+            cc = tt @ rr
+            cc = cc.squeeze(-2)
+            return cc
+
     x = t[..., 0]
     y = t[..., 1]
     z = t[..., 2]
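The corresponding branch in `rot_vec_mul` turns the per-component rotation into one batched matmul: treating `t` as a row vector and multiplying by `rᵀ` gives the same coordinates as applying `r` to `t` as a column vector. An illustrative check against the unrolled formula (shapes are again arbitrary examples):

```python
# Sanity check: (t @ r^T) reproduces the hand-unrolled rotation r @ t.
import torch

r = torch.randn(2, 8, 3, 3)   # [*, 3, 3] rotation-like matrices
t = torch.randn(2, 8, 3)      # [*, 3] points

fast = (t.unsqueeze(-2) @ r.transpose(-2, -1)).squeeze(-2)

x, y, z = t[..., 0], t[..., 1], t[..., 2]
ref = torch.stack(
    [
        r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
        r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
        r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
    ],
    dim=-1,
)

assert torch.allclose(fast, ref, atol=1e-6)
```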
......
+export GC_KERNEL_PATH=./fastfold/habana/fastnn/custom_op/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
+export PYTHONPATH=./:$PYTHONPATH
 # add '--gpus [N]' to use N gpus for inference
 # add '--enable_workflow' to use parallel workflow for data processing
 # add '--use_precomputed_alignments [path_to_alignments]' to use precomputed msa
......
@@ -10,7 +10,7 @@ from tqdm import tqdm
 import fastfold.habana as habana
 from fastfold.config import model_config
 from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader
-from fastfold.habana.distributed import init_dist
+from fastfold.habana.distributed import init_dist, get_data_parallel_world_size
 from fastfold.habana.inject_habana import inject_habana
 from fastfold.model.hub import AlphaFold, AlphaFoldLoss, AlphaFoldLRScheduler
 from fastfold.utils.tensor_utils import tensor_tree_map
@@ -156,7 +156,8 @@ def main():
     model = inject_habana(model)
     model = model.to(device="hpu")
-    model = DDP(model)
+    if get_data_parallel_world_size() > 1:
+        model = DDP(model, gradient_as_bucket_view=True, bucket_cap_mb=400)

     train_dataset, test_dataset = SetupTrainDataset(
         config=config.data,
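Wrapping in `DistributedDataParallel` only when more than one data-parallel rank exists avoids DDP overhead in single-card runs, and `gradient_as_bucket_view=True` with a 400 MB bucket reduces gradient copies and the number of all-reduce launches during backward, which lines up with the commit's goal of improving backward performance. A minimal standalone illustration of the same pattern (the toy model and world-size query below are placeholders, not FastFold code):

```python
# Standalone sketch of conditional DDP wrapping with the bucket settings
# added in this commit; the toy model and world-size query are placeholders.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_data_parallel(model: torch.nn.Module) -> torch.nn.Module:
    world_size = dist.get_world_size() if (dist.is_available() and dist.is_initialized()) else 1
    if world_size > 1:
        # gradient_as_bucket_view avoids a separate gradient copy, and
        # bucket_cap_mb=400 launches fewer, larger all-reduces in backward.
        model = DDP(model, gradient_as_bucket_view=True, bucket_cap_mb=400)
    return model

model = wrap_for_data_parallel(torch.nn.Linear(8, 8))
```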
@@ -201,27 +202,32 @@ def main():
                     isVerbose=args.hmp_verbose)
         print("========= HMP ENABLED!!")

+    idx = 0
     for epoch in range(200):
         model.train()
         train_dataloader = tqdm(train_dataloader)
         for batch in train_dataloader:
             perf = hpu_perf("train step")
-            batch = {k: torch.as_tensor(v).to(device="hpu") for k, v in batch.items()}
+            batch = {k: torch.as_tensor(v).to(device="hpu", non_blocking=True) for k, v in batch.items()}
             optimizer.zero_grad()
+            perf.checknow("prepare input and zero grad")

             output = model(batch)
             perf.checknow("forward")

             batch = tensor_tree_map(lambda t: t[..., -1], batch)
+            perf.checknow("prepare loss input")
             loss, loss_breakdown = criterion(output, batch, _return_breakdown=True)
             perf.checknow("loss")

             loss.backward()
-            train_dataloader.set_postfix(loss=float(loss))
+            if idx % 10 == 0:
+                train_dataloader.set_postfix(loss=float(loss))
             perf.checknow("backward")

             with hmp.disable_casts():
                 optimizer.step()
             perf.checknow("optimizer")
+            idx += 1

         lr_scheduler.step()
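The loop changes are small but backward-relevant: host-to-device copies become `non_blocking`, extra `perf.checknow` checkpoints bracket each phase, and the tqdm postfix is refreshed only every 10 steps, since `float(loss)` forces a readback from the device. A toy CPU-only sketch of the throttled-logging pattern (dummy model and data, not the FastFold training loop):

```python
# Toy illustration of updating the tqdm postfix only every 10 steps,
# because float(loss) synchronizes with the device; dummy data/model only.
import torch
from tqdm import tqdm

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
dataloader = tqdm([(torch.randn(4, 16), torch.randn(4, 1)) for _ in range(50)])

idx = 0
for x, y in dataloader:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    if idx % 10 == 0:               # throttle the synchronizing readback
        dataloader.set_postfix(loss=float(loss))
    optimizer.step()
    idx += 1
```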
......
-DATA_DIR=/mnt/usb/training-demo
+export GC_KERNEL_PATH=./fastfold/habana/fastnn/custom_op/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
+export PYTHONPATH=./:$PYTHONPATH
+DATA_DIR=../FastFold-dataset/train
 hpus_per_node=1
......