Unverified Commit da5fe1a6 authored by shenggan, committed by GitHub

add support on habana platform (#131)



* add habana

* add mask

* fix mask in outer_product_mean

* add dap

* add hmp

* merge training code

* add chunk for inference

* fix extra-msa stack for training

* support ddp in training

* fix inference bugs

* code refactoring for habana

* support hmp training

* enable all inference and train on Gaudi/Gaudi2 with optimized perf with latest base (#139)

* enable all inference and train on Gaudi/Gaudi2 with optimized perf

* refine code to adapt new base

* refine code to fix issues in code review
Co-authored-by: habanachina <habanachina@habana.ai>
Co-authored-by: Leo Zhao <48052473+LeoZhao-Habana@users.noreply.github.com>
parent e9db72d6
@@ -4,10 +4,17 @@
[![](https://img.shields.io/badge/Paper-PDF-green?style=flat&logo=arXiv&logoColor=green)](https://arxiv.org/abs/2203.00854)
![](https://img.shields.io/badge/Made%20with-ColossalAI-blueviolet?style=flat)
![](https://img.shields.io/badge/Habana-support-blue?style=flat&logo=intel&logoColor=blue)
![](https://img.shields.io/github/v/release/hpcaitech/FastFold)
[![GitHub license](https://img.shields.io/github/license/hpcaitech/FastFold)](https://github.com/hpcaitech/FastFold/blob/main/LICENSE)
Optimizing Protein Structure Prediction Model Training and Inference on GPU Clusters
## News :triangular_flag_on_post:
- [2023/01] Compatible with AlphaFold v2.3
- [2023/01] Added support for inference and training of AlphaFold on [Intel Habana](https://habana.ai/) platform. For usage instructions, see [here](#Inference-or-Training-on-Intel-Habana).
<br>
Optimizing Protein Structure Prediction Model Training and Inference on Heterogeneous Clusters
FastFold provides a **high-performance implementation of Evoformer** with the following characteristics.
@@ -201,6 +208,17 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
--kalign_binary_path `which kalign`
```
### Inference or Training on Intel Habana
To run AlphaFold inference or training on the Intel Habana platform, follow the instructions in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/) to set up your environment, either on Amazon EC2 DL1 instances or on premises.
Once you have prepared your dataset and installed FastFold, you can use the following scripts:
```shell
bash habana/inference.sh
bash habana/train.sh
```
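Under the hood, Habana support is toggled through module-level switches in `fastfold.habana`. A minimal sketch of enabling them programmatically, based on the helpers added in this commit (the surrounding training setup is omitted):

```python
import fastfold.habana as habana

habana.enable_habana()  # switch FastFold to Habana code paths (imports habana_frameworks.torch.core)
habana.enable_hmp()     # optional: enable Habana Mixed Precision (HMP)

assert habana.is_habana()
print("HMP enabled:", habana.is_hmp())
```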
## Performance Benchmark
We have included a performance benchmark script in `./benchmark`. You can benchmark the performance of Evoformer using different settings.
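For example, an Evoformer benchmark run might look like the following; the flags are illustrative, so check the scripts under `./benchmark` for the full set of options:

```shell
python ./benchmark/perf.py --msa-length 128 --res-length 256
```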
@@ -237,3 +255,7 @@ Cite this paper if you use FastFold in your research publication.
primaryClass={cs.LG}
}
```
## Acknowledgments
We would like to extend our special thanks to the Intel Habana team for their support in providing us with technology and resources on the Habana platform.
@@ -20,6 +20,7 @@ import ml_collections
import numpy as np
import torch
import fastfold.habana as habana
from fastfold.data import input_pipeline, input_pipeline_multimer
@@ -91,19 +92,18 @@ def np_example_to_features(
np_example=np_example, features=feature_names
)
with torch.no_grad():
if is_multimer:
features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
else:
features = input_pipeline.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
if is_multimer:
input_pipeline_fn = input_pipeline_multimer.process_tensors_from_config
else:
input_pipeline_fn = input_pipeline.process_tensors_from_config
if habana.is_habana():
from habana_frameworks.torch.hpex import hmp
with torch.no_grad(), hmp.disable_casts():
features = input_pipeline_fn(tensor_dict, cfg.common, cfg[mode])
else:
with torch.no_grad():
features = input_pipeline_fn(tensor_dict, cfg.common, cfg[mode])
return {k: v for k, v in features.items()}
@@ -118,7 +118,7 @@ class FeaturePipeline:
def process_features(
self,
raw_features: FeatureDict,
mode: str = "train",
is_multimer: bool = False,
) -> FeatureDict:
return np_example_to_features(
ENABLE_HABANA = False
ENABLE_HMP = False
ENABLE_LAZY_MODE = False
def enable_habana():
global ENABLE_HABANA
ENABLE_HABANA = True
global ENABLE_LAZY_MODE
ENABLE_LAZY_MODE = True
import habana_frameworks.torch.core
def is_habana():
global ENABLE_HABANA
return ENABLE_HABANA
def enable_hmp():
global ENABLE_HMP
ENABLE_HMP = True
def is_hmp():
global ENABLE_HMP
return ENABLE_HMP
\ No newline at end of file
from .comm import (All_to_All, _gather, _reduce, _split, col_to_row, copy,
gather, reduce, row_to_col, scatter)
from .core import init_dist
__all__ = [
'init_dist', '_reduce', '_split', '_gather', 'copy', 'scatter', 'reduce', 'gather',
'col_to_row', 'row_to_col', 'All_to_All'
]
from typing import Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from .core import (ensure_divisibility, get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
def divide(numerator, denominator):
ensure_divisibility(numerator, denominator)
return numerator // denominator
def _reduce(tensor: Tensor) -> Tensor:
if dist.get_world_size() == 1:
return tensor
dist.all_reduce(tensor,
op=dist.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
async_op=False)
return tensor
def _split(tensor: Tensor, dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor
split_size = divide(tensor.shape[dim], get_tensor_model_parallel_world_size())
tensor_list = torch.split(tensor, split_size, dim=dim)
output = tensor_list[get_tensor_model_parallel_rank()].contiguous()
return output
def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
    if get_tensor_model_parallel_world_size() == 1:
        return tensor
    if dim == 1 and list(tensor.shape)[0] == 1:
        # Fast path: all_gather directly into chunk views of one
        # preallocated output buffer, avoiding a separate concat.
        output_shape = list(tensor.shape)
        output_shape[1] *= get_tensor_model_parallel_world_size()
        output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
        tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
        dist.all_gather(list(tensor_list),
                        tensor,
                        group=get_tensor_model_parallel_group(),
                        async_op=False)
    else:
        tensor_list = [
            torch.empty_like(tensor) for _ in range(get_tensor_model_parallel_world_size())
        ]
        dist.all_gather(tensor_list,
                        tensor,
                        group=get_tensor_model_parallel_group(),
                        async_op=False)
        output = torch.cat(tensor_list, dim=dim)
    return output
def copy(input: Tensor) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Copy.apply(input)
return input
class Copy(torch.autograd.Function):
@staticmethod
def forward(ctx: "Copy", input: Tensor) -> Tensor:
return input
@staticmethod
def backward(ctx: "Copy", grad_output: Tensor) -> Tensor:
return _reduce(grad_output)
def scatter(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Scatter.apply(input, dim)
else:
input = _split(input, dim=dim)
return input
class Scatter(torch.autograd.Function):
@staticmethod
def forward(ctx: "Scatter", input: Tensor, dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([dim]))
return _split(input, dim=dim)
@staticmethod
def backward(ctx: "Scatter", grad_output: Tensor) -> Tuple[Tensor]:
dim, = ctx.saved_tensors
return _gather(grad_output, dim=int(dim)), None
def reduce(input: Tensor) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Reduce.apply(input)
else:
input = _reduce(input)
return input
class Reduce(torch.autograd.Function):
@staticmethod
def forward(ctx: "Reduce", input: Tensor) -> Tensor:
return _reduce(input)
@staticmethod
def backward(ctx: "Reduce", grad_output: Tensor) -> Tensor:
return grad_output
def gather(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Gather.apply(input, dim)
else:
input = _gather(input, dim=dim)
return input
class Gather(torch.autograd.Function):
@staticmethod
def forward(ctx: "Gather", input: Tensor, dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([dim]))
return _gather(input, dim=dim)
@staticmethod
def backward(ctx: "Gather", grad_output: Tensor) -> Tuple[Tensor]:
dim, = ctx.saved_tensors
return _split(grad_output, dim=int(dim)), None
def _all_to_all(tensor: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
    if dist.get_world_size() == 1:
        return tensor
    # all_to_all_single exchanges equal slices along dim 0, so move the
    # input dim to the front first.
    tensor = tensor.transpose(in_dim, 0).contiguous()
    output = torch.empty_like(tensor)
    dist.all_to_all_single(output, tensor, group=get_tensor_model_parallel_group())
    output = output.transpose(in_dim, 0).contiguous()
    # The received slices arrive stacked along in_dim; re-chunk and
    # concatenate along out_dim to complete the exchange of parallel dims.
    tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=in_dim)
    return torch.cat(tensor_list, dim=out_dim)
def col_to_row(input_: Tensor) -> Tensor:
if torch.is_grad_enabled() and input_.requires_grad:
input_ = All_to_All.apply(input_, 1, 2)
else:
input_ = _all_to_all(input_, in_dim=1, out_dim=2)
return input_
def row_to_col(input_: Tensor) -> Tensor:
if torch.is_grad_enabled() and input_.requires_grad:
input_ = All_to_All.apply(input_, 2, 1)
else:
input_ = _all_to_all(input_, in_dim=2, out_dim=1)
return input_
class All_to_All(torch.autograd.Function):
@staticmethod
def forward(ctx: "All_to_All", input_: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([in_dim, out_dim]))
return _all_to_all(input_, in_dim=in_dim, out_dim=out_dim)
@staticmethod
def backward(ctx: "All_to_All", grad_output: Tensor) -> Tuple[Tensor]:
saved_tensors = ctx.saved_tensors[0]
return _all_to_all(grad_output, in_dim=int(saved_tensors[1]),
out_dim=int(saved_tensors[0])), None, None
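For orientation, a round trip through these primitives might look like the sketch below (assuming `init_dist` has created the HCCL process groups, the job was launched with one process per device, and tensors live on HPU; shapes are illustrative):

```python
import torch
from fastfold.habana.distributed import init_dist, scatter, gather

init_dist(tensor_model_parallel_size_=2)  # rank/world size come from MPI

x = torch.randn(1, 8, 64, 16, device='hpu')
x_local = scatter(x, dim=1)      # each rank keeps a 1/world_size slice of dim 1
x_full = gather(x_local, dim=1)  # reassembled on every rank; equals x
```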
import os
import torch
import torch.distributed as dist
from mpi4py import MPI
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# These values enable us to change the mpu sizes on the fly.
_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_TENSOR_MODEL_PARALLEL_RANK = None
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)
def set_missing_distributed_environ(key, value):
if key not in os.environ:
os.environ[str(key)] = str(value)
def init_dist(tensor_model_parallel_size_=1):
    comm = MPI.COMM_WORLD
    world_size = comm.Get_size()
    rank = comm.Get_rank()
    # Respect user-provided rendezvous settings; only fill in defaults.
    set_missing_distributed_environ('MASTER_ADDR', 'localhost')
    set_missing_distributed_environ('MASTER_PORT', 12340)
    import habana_frameworks.torch.distributed.hccl
    dist.init_process_group(backend='hccl', rank=rank, world_size=world_size)
world_size = dist.get_world_size()
rank = dist.get_rank()
# check dist config
ensure_divisibility(world_size, tensor_model_parallel_size_)
data_parallel_size_ = world_size // tensor_model_parallel_size_
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized'
for i in range(tensor_model_parallel_size_):
ranks = range(i, world_size, tensor_model_parallel_size_)
group = dist.new_group(ranks)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
'tensor model parallel group is already initialized'
# Build the model-parallel groups.
for i in range(data_parallel_size_):
ranks = range(i * tensor_model_parallel_size_, (i + 1) * tensor_model_parallel_size_)
group = dist.new_group(ranks)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
if dist.get_rank() == 0:
print('> initialize tensor model parallel with size {}'.format(tensor_model_parallel_size_))
print('> initialize data parallel with size {}'.format(data_parallel_size_))
def dap_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
return False
return True
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
return _TENSOR_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
global _TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
return _TENSOR_MODEL_PARALLEL_WORLD_SIZE
return dist.get_world_size(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
global _TENSOR_MODEL_PARALLEL_RANK
if _TENSOR_MODEL_PARALLEL_RANK is not None:
return _TENSOR_MODEL_PARALLEL_RANK
return dist.get_rank(group=get_tensor_model_parallel_group())
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return dist.get_world_size(group=get_data_parallel_group())
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return dist.get_rank(group=get_data_parallel_group())
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = dist.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
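Because `init_dist` reads rank and world size from `MPI.COMM_WORLD` before handing off to the HCCL backend, processes are launched with MPI; an illustrative launch follows (the entry-point name is hypothetical):

```shell
# world_size = 8; tensor_model_parallel_size_ must divide it
mpirun -np 8 python your_train_script.py
```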
from functools import partial
from typing import Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from fastfold.habana.distributed import All_to_All, gather, scatter
from fastfold.utils.checkpointing import checkpoint_blocks
from .msa import ExtraMSACore, MSAStack
from .ops import Linear, OutProductMean
from .triangle import PairStack
import habana_frameworks.torch.core as htcore
class Evoformer(nn.Module):
def __init__(self,
c_m: int,
c_z: int,
first_block: bool,
last_block: bool,
is_multimer: bool = False):
super(Evoformer, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa = MSAStack(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair = PairStack(d_pair=c_z)
self.is_multimer = is_multimer
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
        # DAP: pad the residue dimension so it splits evenly across ranks
        # (always at least one extra position; trimmed again in the last block).
        dap_size = dist.get_world_size()
        seq_length = pair_mask.size(-1)
        padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
if self.is_multimer:
m = scatter(m, dim=2)
else:
m = scatter(m, dim=1)
z = scatter(z, dim=1)
# msa_mask = msa_mask.unsqueeze(0)
# pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
if not self.is_multimer:
m = self.msa(m, z, msa_mask)
z = self.communication(m, msa_mask, z)
m = All_to_All.apply(m, 1, 2)
z = self.pair(z, pair_mask)
else:
z = self.communication(m, msa_mask, z)
z_ori = z
m = All_to_All.apply(m, 1, 2)
z = self.pair(z, pair_mask)
m = self.msa(m, z_ori, msa_mask)
if self.last_block:
m = m.squeeze(0)
z = z.squeeze(0)
if self.is_multimer:
m = gather(m, dim=1)
else:
m = gather(m, dim=0)
z = gather(z, dim=0)
m = m[:, :-padding_size, :]
z = z[:-padding_size, :-padding_size, :]
htcore.mark_step()
return m, z
class EvoformerStack(nn.Module):
"""
Main Evoformer trunk.
Implements Algorithm 6.
"""
def __init__(
self,
c_m: int,
c_z: int,
c_s: int,
no_blocks: int,
blocks_per_ckpt: int,
clear_cache_between_blocks: bool = False,
is_multimer: bool = False,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair channel dimension
c_hidden_msa_att:
Hidden dimension in MSA attention
c_hidden_opm:
Hidden dimension in outer product mean module
c_hidden_mul:
Hidden dimension in multiplicative updates
c_hidden_pair_att:
Hidden dimension in triangular attention
c_s:
Channel dimension of the output "single" embedding
no_heads_msa:
Number of heads used for MSA attention
no_heads_pair:
Number of heads used for pair attention
no_blocks:
Number of Evoformer blocks in the stack
transition_n:
Factor by which to multiply c_m to obtain the MSATransition
hidden dimension
msa_dropout:
Dropout rate for MSA activations
pair_dropout:
Dropout used for pair activations
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
"""
super(EvoformerStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for block_id in range(no_blocks):
block = Evoformer(
c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == no_blocks - 1),
is_multimer=is_multimer,
)
self.blocks.append(block)
self.linear = Linear(c_m, c_s)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
) for b in self.blocks
]
if torch.is_grad_enabled():
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
else:
for b in blocks:
m, z = b(m, z)
s = self.linear(m[..., 0, :, :])
htcore.mark_step()
return m, z, s
class ExtraMSABlock(nn.Module):
def __init__(self,
c_m: int,
c_z: int,
first_block: bool,
last_block: bool,
is_multimer: bool = False):
super(ExtraMSABlock, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa_stack = ExtraMSACore(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair_stack = PairStack(d_pair=c_z)
self.is_multimer = is_multimer
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
htcore.mark_step()
dap_size = dist.get_world_size()
seq_length = pair_mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
if self.is_multimer:
m = scatter(m, dim=2)
else:
m = scatter(m, dim=1)
z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
if not self.is_multimer:
m = self.msa_stack(m, z, msa_mask)
z = self.communication(m, msa_mask, z)
m = All_to_All.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
else:
z = self.communication(m, msa_mask, z)
z_ori = z
m = All_to_All.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = self.msa_stack(m, z_ori, msa_mask)
if self.last_block:
m = m.squeeze(0)
z = z.squeeze(0)
if self.is_multimer:
m = gather(m, dim=1)
else:
m = gather(m, dim=0)
z = gather(z, dim=0)
m = m[:, :-padding_size, :]
z = z[:-padding_size, :-padding_size, :]
htcore.mark_step()
return m, z
class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def __init__(
self,
c_m: int,
c_z: int,
no_blocks: int,
blocks_per_ckpt: int,
clear_cache_between_blocks: bool = False,
is_multimer: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.blocks = nn.ModuleList()
for block_id in range(no_blocks):
block = ExtraMSABlock(
c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == no_blocks - 1),
is_multimer=is_multimer,
)
self.blocks.append(block)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
) for b in self.blocks
]
if torch.is_grad_enabled():
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
else:
for b in blocks:
m, z = b(m, z)
return z
# CustomOp API Usage in PyTorch
This README provides an example of how to write custom PyTorch Ops using a TPC Kernel supported on an HPU device. For more details, refer to [PyTorch CustomOP API](https://docs.habana.ai/en/latest/PyTorch/PyTorch_CustomOp_API/page_index.html) documentation.
For further information on training deep learning models using Gaudi, refer to [developer.habana.ai](https://developer.habana.ai/resources/).
## Table of Contents
* [Model-References](../../../README.md)
* [Prerequisites](#prerequisites)
* [Content](#content)
* [Build and Run with Custom Kernels](#build-and-run-with-custom-kernels)
* [Important to Know](#important-to-know)
* [Applying CustomOps to a Real Training Model Example](#applying-customops-to-a-real-training-model-example)
* [Known Issues](#known-issues)
## Prerequisites
- A TPC kernel on which the HpuKernel will run. To write a CustomOp, you must first define the TPC kernel that the HpuKernel will run on. This document provides the required steps for using the existing default TPC kernels `relu_fwd_f32` and `relu_bwd_f32`, as well as the custom kernel `custom_op::custom_relu`, to implement a CustomOp. For further information on how to write TPC kernels, refer to the [Habana Custom Kernel GitHub page](https://github.com/HabanaAI/Habana_Custom_Kernel).
- The **habana-torch-plugin** Python package must be installed. Make sure to install it by following the instructions in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html).
## Content
- C++ file with the **custom_op::fusedsoftmax** and **custom_op::fusedsoftmax_bias** definitions and kernel implementations on HPU:
- `fusedsoftmax` performs a fused softmax over the input and mask.
- `fusedsoftmax_bias` performs a fused softmax over the input, mask, and bias.
- `setup.py` file for building the solution:
- To compile the Op for Gaudi, run ```python setup.py build```.
- To compile the Op for Gaudi2, run ```python setup2.py build```.
- Python test to run and validate `fusedsoftmax` and `fusedsoftmax_bias`:
- ```python hpu_fusedsoftmax_test.py```
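After building, the ops are reachable through the `torch.ops` namespace once the shared library is loaded. A minimal sketch (the library path varies with your Python version and platform; shapes follow the test script):

```python
import torch
import habana_frameworks.torch.core  # registers the HPU backend

torch.ops.load_library(
    "build/lib.linux-x86_64-3.8/hpu_fusedsoftmax.cpython-38-x86_64-linux-gnu.so")

input = torch.randn(1, 512, 4, 512, 512).to('hpu')
mask = torch.randn(1, 512, 1, 1, 512).to('hpu')
out = torch.ops.custom_op.fusedsoftmax(input, mask, -1)
```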
## Build and Run with Custom Kernels
To build `fusedsoftmax` and `fusedsoftmax_bias`, run ```python setup.py build``` (use ```python setup2.py build``` on Gaudi2).
## Important to Know
This is an example of an Op implementing both forward and backward.
The forward and backward CustomOps are used for training the model by extending the [torch.autograd](https://pytorch.org/docs/stable/notes/extending.html) package.
## Known Issues
BF16 and HMP are not supported yet. To use the CustomOp in a topology, run the FP32 variant only.
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################
from .fusedsoftmax import fused_softmax, fused_softmax_bias

__all__ = ['fused_softmax', 'fused_softmax_bias']
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################
import torch
import os
import habana_frameworks.torch.core
# Path to the prebuilt custom-op library (built by setup.py; the
# lib.linux-* directory name depends on the Python version and platform).
custom_fusedsoftmax_op_lib_path = "./build/lib.linux-x86_64-3.8/hpu_fusedsoftmax.cpython-38-x86_64-linux-gnu.so"
base_dir = os.path.dirname(os.path.realpath(__file__))
torch.ops.load_library(os.path.join(base_dir, custom_fusedsoftmax_op_lib_path))
class FusedSoftmaxFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, dim):
# ctx is a context object that can be used to stash information
# for backward computation
tensor = torch.ops.custom_op.fusedsoftmax(input, mask, dim)
ctx.y = tensor
ctx.dim = dim
return tensor
@staticmethod
def backward(ctx, grad_output):
if grad_output is None:
return None
y = ctx.y
ctx.y = None
dim = ctx.dim
ctx.dim = None
grad_input = torch.ops.custom_op.fusedsoftmax_backward(y, grad_output, dim)
return grad_input, None, None
class FusedSoftmaxBiasFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, bias, dim):
# ctx is a context object that can be used to stash information
# for backward computation
tensor = torch.ops.custom_op.fusedsoftmax_bias(input, mask, bias, dim)
ctx.y = tensor
ctx.dim = dim
ctx.use_bias = False
if bias is not None:
ctx.use_bias = True
return tensor
@staticmethod
def backward(ctx, grad_output):
if grad_output is None:
return None
y = ctx.y
ctx.y = None
dim = ctx.dim
ctx.dim = None
grad_input = torch.ops.custom_op.fusedsoftmax_backward(y, grad_output, dim)
grad_bias = None
if ctx.use_bias:
grad_bias = torch.sum(grad_input, dim=-4, keepdim=True)
return grad_input, None, grad_bias, None
ENABLE_OPT = True

def fused_softmax(input, mask, dim):
    # The fused HPU kernel expects a mask broadcastable as
    # [..., b2, 1, 1, len]; otherwise fall back to plain PyTorch ops.
    if ENABLE_OPT and input[..., :, :1, :1, :].shape == mask.shape:
        return FusedSoftmaxFunction.apply(input, mask, dim)
    else:
        # Out-of-place add so the caller's tensor is not mutated.
        return torch.softmax(input + mask, dim=dim)

def fused_softmax_bias(input, mask, bias, dim):
    if ENABLE_OPT and input[..., :, :1, :1, :].shape == mask.shape \
            and input[..., :1, :, :, :].shape == bias.shape:
        return FusedSoftmaxBiasFunction.apply(input, mask, bias, dim)
    else:
        # Out-of-place adds so the caller's tensors are not mutated.
        return torch.softmax(input + mask + bias, dim=dim)
/******************************************************************************
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################
*******************************************************************************/
#include "hpu_custom_op.h"
#include <torch/extension.h>
#include <perf_lib_layer_params.h>
struct SoftMaxParam
{
int32_t axis;
bool with_bias;
};
bool register_fusedsoftmax() {
// Registering custom_op::fusedsoftmax
// inputs desc
habana::custom_op::InputDesc input_a_desc{
habana::custom_op::input_type::TENSOR, 0};
habana::custom_op::InputDesc input_b_desc{
habana::custom_op::input_type::TENSOR, 1};
habana::custom_op::InputDesc input_d_desc{
habana::custom_op::input_type::USER_PARAMS, 2};
std::vector<habana::custom_op::InputDesc> inputs_desc{
input_a_desc, input_b_desc, input_d_desc};
// output desc
// output shape callback
auto output_size_lambda =
[](const at::Stack& inputs) -> std::vector<int64_t> {
auto self = inputs[0].toTensor(); // input
std::vector<int64_t> result_sizes = self.sizes().vec();
return result_sizes;
};
habana::custom_op::OutputDesc output_desc{
0, c10::ScalarType::Float, output_size_lambda};
std::vector<habana::custom_op::OutputDesc> outputs_desc{
output_desc};
// user param callback
auto user_params_lambda = [](const at::Stack& inputs, size_t& size) {
HPU_PARAMS_STUB(SoftMaxParam);
params->with_bias = false;
int dim = inputs[2].toInt();
if (dim > 0)
params->axis = inputs[0].toTensor().dim() - dim - 1;
else
params->axis = - dim - 1;
return params;
};
// actual register
REGISTER_CUSTOM_OP_ATTRIBUTES(
"custom_op::fusedsoftmax", //schema name
#ifdef GAUDI2
"fusedsoftmax_fwd_f32_gaudi2", // guid
#else
"fusedsoftmax_fwd_f32", // guid
#endif
inputs_desc,
outputs_desc,
user_params_lambda);
std::cout << "cpp registered custom_op::fusedsoftmax\n";
return true;
}
bool register_fusedsoftmax_bias() {
// Registering custom_op::fusedsoftmax
// inputs desc
habana::custom_op::InputDesc input_a_desc{
habana::custom_op::input_type::TENSOR, 0};
habana::custom_op::InputDesc input_b_desc{
habana::custom_op::input_type::TENSOR, 1};
habana::custom_op::InputDesc input_c_desc{
habana::custom_op::input_type::TENSOR, 2};
habana::custom_op::InputDesc input_d_desc{
habana::custom_op::input_type::USER_PARAMS, 3};
std::vector<habana::custom_op::InputDesc> inputs_desc{
input_a_desc, input_b_desc, input_c_desc, input_d_desc};
// output desc
// output shape callback
auto output_size_lambda =
[](const at::Stack& inputs) -> std::vector<int64_t> {
auto self = inputs[0].toTensor(); // input
std::vector<int64_t> result_sizes = self.sizes().vec();
return result_sizes;
};
habana::custom_op::OutputDesc output_desc{
0, c10::ScalarType::Float, output_size_lambda};
std::vector<habana::custom_op::OutputDesc> outputs_desc{
output_desc};
// user param callback
auto user_params_lambda = [](const at::Stack& inputs, size_t& size) {
HPU_PARAMS_STUB(SoftMaxParam);
params->with_bias = true;
int dim = inputs[3].toInt();
if (dim > 0)
params->axis = inputs[0].toTensor().dim() - dim - 1;
else
params->axis = - dim - 1;
return params;
};
// actual register
REGISTER_CUSTOM_OP_ATTRIBUTES(
"custom_op::fusedsoftmax_bias", //schema name
#ifdef GAUDI2
"fusedsoftmax_bias_fwd_f32_gaudi2", // guid
#else
"fusedsoftmax_bias_fwd_f32", // guid
#endif
inputs_desc,
outputs_desc,
user_params_lambda);
std::cout << "cpp registered custom_op::fusedsoftmax_bias\n";
return true;
}
bool register_custom_fusedsoftmax_backward() {
// inputs desc
habana::custom_op::InputDesc y_desc{
habana::custom_op::input_type::TENSOR, 0};
habana::custom_op::InputDesc grad_desc{
habana::custom_op::input_type::TENSOR, 1};
habana::custom_op::InputDesc dim_desc{
habana::custom_op::input_type::USER_PARAMS, 2};
std::vector<habana::custom_op::InputDesc> inputs_desc{
y_desc, grad_desc, dim_desc};
auto output_input_size_lambda =
[](const at::Stack& inputs) -> std::vector<int64_t> {
auto self = inputs[0].toTensor(); // input
std::vector<int64_t> result_sizes = self.sizes().vec();
return result_sizes;
};
habana::custom_op::OutputDesc input_grad_desc{
0, c10::ScalarType::Float, output_input_size_lambda};
std::vector<habana::custom_op::OutputDesc> outputs_desc{
input_grad_desc};
// user param callback
auto user_params_lambda = [](const at::Stack& inputs, size_t& size) {
HPU_PARAMS_STUB(ns_Softmax::Params);
params->dim = 0;
return params;
};
// actual register
REGISTER_CUSTOM_OP_ATTRIBUTES(
"custom_op::fusedsoftmax_backward", //schema name
#ifdef GAUDI2
"softmax_bwd_f32", // guid
#else
"softmax_bwd_f32", // guid
#endif
inputs_desc,
outputs_desc,
user_params_lambda);
std::cout << "cpp registered custom_op::fusedsoftmax_backward\n";
return true;
}
at::Tensor fusedsoftmax_execute(
torch::Tensor input,
torch::Tensor mask,
at::Scalar dim) {
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float, "Input input_a expected to be Float tensor");
// Registering the custom op, need to be called only once
static bool registered = register_fusedsoftmax();
TORCH_CHECK(registered, "fusedsoftmax kernel not registered" );
std::vector<c10::IValue> inputs{input, mask, dim};
// Get custom op descriptor from registry
auto op_desc = habana::custom_op::HabanaCustomOpDescriptor::getCustomOpDescriptor("custom_op::fusedsoftmax");
// Actual call for op execution
std::vector<at::Tensor> output = op_desc.execute(inputs);
// op_desc.execute will always return a vector
return output[0];
}
at::Tensor fusedsoftmax_bias_execute(
torch::Tensor input,
torch::Tensor mask,
torch::Tensor bias,
at::Scalar dim) {
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float, "Input input_a expected to be Float tensor");
// Registering the custom op, need to be called only once
static bool registered = register_fusedsoftmax_bias();
TORCH_CHECK(registered, "fusedsoftmax_bias kernel not registered" );
std::vector<c10::IValue> inputs{input, mask, bias, dim};
// Get custom op descriptor from registry
auto op_desc = habana::custom_op::HabanaCustomOpDescriptor::getCustomOpDescriptor("custom_op::fusedsoftmax_bias");
// Actual call for op execution
std::vector<at::Tensor> output = op_desc.execute(inputs);
// op_desc.execute will always return a vector
return output[0];
}
at::Tensor fusedsoftmax_backward_execute(
torch::Tensor y,
torch::Tensor grad,
at::Scalar dim) {
TORCH_CHECK(y.scalar_type() == c10::ScalarType::Float, "Input y expected to be Float tensor");
TORCH_CHECK(grad.scalar_type() == c10::ScalarType::Float, "Input grad expected to be Float tensor");
// Registering the custom op, need to be called only once
static bool registered = register_custom_fusedsoftmax_backward();
TORCH_CHECK(registered, "custom_fusedsoftmax_backward kernel not registered" );
std::vector<c10::IValue> inputs{y, grad, dim};
// Get custom op descriptor from registry
auto op_desc = habana::custom_op::HabanaCustomOpDescriptor::getCustomOpDescriptor("custom_op::fusedsoftmax_backward");
// Actual call for op execution
std::vector<at::Tensor> output = op_desc.execute(inputs);
// op_desc.execute will always return a vector
return output[0];
}
TORCH_LIBRARY(custom_op, m) {
m.def("fusedsoftmax(Tensor self, Tensor mask, Scalar dim) -> Tensor");
m.def("fusedsoftmax_bias(Tensor self, Tensor mask, Tensor bias, Scalar dim) -> Tensor");
m.def("fusedsoftmax_backward(Tensor y, Tensor grad, Scalar dim) -> Tensor");
}
TORCH_LIBRARY_IMPL(custom_op, HPU, m) {
m.impl("fusedsoftmax", fusedsoftmax_execute);
m.impl("fusedsoftmax_bias", fusedsoftmax_bias_execute);
m.impl("fusedsoftmax_backward", fusedsoftmax_backward_execute);
}
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################
import torch
from fusedsoftmax import fused_softmax, fused_softmax_bias
def test_fusedsoftmax_op_function():
print(torch.ops.custom_op.fusedsoftmax)
print(torch.ops.custom_op.fusedsoftmax_bias)
# print(torch.ops.custom_op.custom_relu_backward)
input = torch.randn(1, 512, 4, 512, 512)
mask = torch.randn(1, 512, 1, 1, 512)
bias = torch.randn(1, 1, 4, 512, 512)
dim = -1
input_hpu = input.to('hpu')
mask_hpu = mask.to('hpu')
out = input + mask
output_cpu = torch.softmax(out, dim=dim)
output_hpu = fused_softmax(input_hpu, mask_hpu, dim)
assert((abs(output_hpu.cpu() - output_cpu) < 1e-6).all())
print("fused_softmax test passed")
input_hpu = input.to('hpu')
mask_hpu = mask.to('hpu')
bias_hpu = bias.to('hpu')
out = input + mask
out += bias
output_cpu = torch.softmax(out, dim=dim)
output_hpu = fused_softmax_bias(input_hpu, mask_hpu, bias_hpu, dim)
assert((abs(output_hpu.cpu() - output_cpu) < 1e-6).all())
print("fused_softmax_bias test passed")
test_fusedsoftmax_op_function()
def test_fusedsoftmax_bias_op_backward_function():
print("fused_softmax_bias_backward")
input = torch.randn(1, 512, 4, 512, 512, requires_grad=True)
mask = torch.randn(1, 512, 1, 1, 512, requires_grad=False)
bias = torch.randn(1, 1, 4, 512, 512, requires_grad=True)
dim = -1
# cpu reference
add_mask_cpu = input + mask
add_mask_cpu += bias
softmax_cpu = torch.softmax(add_mask_cpu, dim=dim)
input_hpu = input.to('hpu').detach()
input_hpu.requires_grad = True
mask_hpu = mask.to('hpu').detach()
mask_hpu.requires_grad = False
bias_hpu = bias.to('hpu').detach()
bias_hpu.requires_grad = True
softmax_hpu = fused_softmax_bias(input_hpu, mask_hpu, bias_hpu, dim)
assert((abs(softmax_hpu.detach().cpu() - softmax_cpu.detach()) < 1e-6).all())
grad_cpu = torch.ones_like(softmax_cpu)
softmax_cpu.backward(grad_cpu)
grad_hpu = grad_cpu.to('hpu')
softmax_hpu.backward(grad_hpu)
input_bwd_cpu = input.grad
input_bwd_hpu = input_hpu.grad
assert((abs(input_bwd_hpu.detach().cpu() - input_bwd_cpu.detach()) < 1e-6).all())
bias_bwd_cpu = bias.grad
bias_bwd_hpu = bias_hpu.grad
assert((abs(bias_bwd_hpu.detach().cpu() - bias_bwd_cpu.detach()) < 1e-6).all())
print("fused_softmax_bias_backward test passed")
test_fusedsoftmax_bias_op_backward_function()
def test_fusedsoftmax_op_backward_function():
print(torch.ops.custom_op.fusedsoftmax_backward)
input = torch.randn(1, 512, 4, 512, 512, requires_grad=True)
mask = torch.randn(1, 512, 1, 1, 512, requires_grad=False)
dim = -1
# cpu reference
add_mask_cpu = input + mask
softmax_cpu = torch.softmax(add_mask_cpu, dim=dim)
input_hpu = input.to('hpu').detach()
input_hpu.requires_grad = True
mask_hpu = mask.to('hpu').detach()
mask_hpu.requires_grad = False
softmax_hpu = fused_softmax(input_hpu, mask_hpu, dim)
assert((abs(softmax_hpu.detach().cpu() - softmax_cpu.detach()) < 1e-6).all())
grad_cpu = torch.ones_like(softmax_cpu)
softmax_cpu.backward(grad_cpu)
grad_hpu = grad_cpu.to('hpu')
softmax_hpu.backward(grad_hpu)
input_bwd_cpu = input.grad
input_bwd_hpu = input_hpu.grad
assert((abs(input_bwd_hpu.detach().cpu() - input_bwd_cpu.detach()) < 1e-6).all())
print("fused_softmax_backward test passed")
test_fusedsoftmax_op_backward_function()
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################
from setuptools import setup
from torch.utils import cpp_extension
from habana_frameworks.torch.utils.lib_utils import get_include_dir, get_lib_dir
import os
import pybind11
torch_include_dir = get_include_dir()
torch_lib_dir = get_lib_dir()
habana_modules_directory = "/usr/include/habanalabs"
pybind_include_path = pybind11.get_include()
setup(name='hpu_fusedsoftmax',
ext_modules=[cpp_extension.CppExtension('hpu_fusedsoftmax', ['hpu_fusedsoftmax.cpp'],
language='c++', extra_compile_args=["-std=c++17"],
libraries=['habana_pytorch_plugin'],
library_dirs=[torch_lib_dir])],
include_dirs=[torch_include_dir,
habana_modules_directory,
pybind_include_path,
],
cmdclass={'build_ext': cpp_extension.BuildExtension})
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################
from setuptools import setup
from torch.utils import cpp_extension
from habana_frameworks.torch.utils.lib_utils import get_include_dir, get_lib_dir
import os
import pybind11
torch_include_dir = get_include_dir()
torch_lib_dir = get_lib_dir()
habana_modules_directory = "/usr/include/habanalabs"
pybind_include_path = pybind11.get_include()
setup(name='hpu_fusedsoftmax',
ext_modules=[cpp_extension.CppExtension('hpu_fusedsoftmax', ['hpu_fusedsoftmax.cpp'],
language='c++', extra_compile_args=["-std=c++17"], define_macros=[("GAUDI2", None)],
libraries=['habana_pytorch_plugin'],
library_dirs=[torch_lib_dir])],
include_dirs=[torch_include_dir,
habana_modules_directory,
pybind_include_path,
],
cmdclass={'build_ext': cpp_extension.BuildExtension})
import math
import numpy as np
import torch.nn as nn
def glorot_uniform_af(x, gain=1.0):
"""
initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
In PyTorch:
[feature_out, feature_in, n_head ...]
In Jax:
[... n_head, feature_in, feature_out]
However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like:
[feature_in, n_head, feature_out]
In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
"""
fan_in, fan_out = x.shape[-2:]
if len(x.shape) > 2:
receptive_field_size = np.prod(x.shape[:-2])
fan_in *= receptive_field_size
fan_out *= receptive_field_size
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
nn.init.uniform_(x, -dev, dev)
return x
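As a quick illustration of that layout (shapes are arbitrary), a `[feature_in, n_head, feature_out]` weight takes its fan sizes from the last two dims, scaled by the leading dims:

```python
import torch

# fan_in/fan_out come from the last two dims (8, 32), both scaled by the
# leading "receptive field" dim (64).
w = torch.empty(64, 8, 32)
glorot_uniform_af(w, gain=1.0)
```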
import torch
import torch.nn.functional as F
def bias_sigmod_ele(y, bias, z):
return torch.sigmoid(y + bias) * z
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
out = (x + bias) * F.dropout(dropmask, p=prob, training=training)
out = residual + out
return out
def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
dropout_mask: torch.Tensor, Z_raw: torch.Tensor, prob: float,
training: bool) -> torch.Tensor:
return Z_raw + F.dropout(dropout_mask, p=prob, training=training) * (g * (ab + b))
\ No newline at end of file
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm
from fastfold.habana.distributed import gather, row_to_col, scatter
from .kernel import bias_dropout_add
from .ops import GlobalAttention, SelfAttention, Transition
class MSAColumnGlobalAttention(nn.Module):
def __init__(self, d_node, c=8, n_head=8):
super(MSAColumnGlobalAttention, self).__init__()
self.d_node = d_node
self.c = c
self.n_head = n_head
self.layernormM = LayerNorm(d_node)
self.global_attention = GlobalAttention(qkv_dim=d_node, c=c, n_head=n_head, out_dim=d_node)
def forward(self, M_raw, M_mask):
M = M_raw.transpose(-2, -3)
M = self.layernormM(M)
M_mask = M_mask.transpose(-1, -2)
M = self.global_attention(M, M_mask)
M = M.transpose(-2, -3)
return M_raw + M
class MSARowAttentionWithPairBias(nn.Module):
def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
super(MSARowAttentionWithPairBias, self).__init__()
self.d_node = d_node
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernormM = LayerNorm(d_node)
self.layernormZ = LayerNorm(d_pair)
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
std=1.0 / math.sqrt(d_pair))
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)
self.attention = SelfAttention(qkv_dim=d_node,
c=c,
n_head=n_head,
out_dim=d_node,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)
def forward(self, M_raw, Z, M_mask):
## Input projections
M = self.layernormM(M_raw)
Z = self.layernormZ(Z)
b = F.linear(Z, self.linear_b_weights)
b = gather(b, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
M = self.attention(M, M_mask, b)
dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
return bias_dropout_add(M,
self.out_bias,
dropout_mask,
M_raw,
prob=self.p_drop,
training=self.training)
class MSAColumnAttention(nn.Module):
def __init__(self, d_node, c=32, n_head=8):
super(MSAColumnAttention, self).__init__()
self.d_node = d_node
self.c = c
self.n_head = n_head
self.layernormM = LayerNorm(d_node)
self.attention = SelfAttention(qkv_dim=d_node,
c=c,
n_head=n_head,
out_dim=d_node,
gating=True)
def forward(self, M_raw, M_mask):
M = M_raw.transpose(-2, -3)
M = self.layernormM(M)
M_mask = M_mask.transpose(-1, -2)
M = self.attention(M, M_mask)
M = M.transpose(-2, -3)
return M_raw + M
class MSAStack(nn.Module):
def __init__(self, d_node, d_pair, p_drop=0.15):
super(MSAStack, self).__init__()
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
d_pair=d_pair,
p_drop=p_drop)
self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
self.MSATransition = Transition(d=d_node)
def forward(self, node, pair, node_mask):
node_mask_row = scatter(node_mask, dim=1)
node = self.MSARowAttentionWithPairBias(node, pair, node_mask_row)
node = row_to_col(node)
node_mask_col = scatter(node_mask, dim=2)
node = self.MSAColumnAttention(node, node_mask_col)
node = self.MSATransition(node)
return node
class ExtraMSACore(nn.Module):
def __init__(self, d_node, d_pair, p_drop=0.15):
super(ExtraMSACore, self).__init__()
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
d_pair=d_pair,
p_drop=p_drop,
c=8)
self.MSAColumnAttention = MSAColumnGlobalAttention(d_node=d_node, c=8)
self.MSATransition = Transition(d=d_node)
def forward(self, node, pair, node_mask):
node_mask_row = scatter(node_mask, dim=1)
node = self.MSARowAttentionWithPairBias(node, pair, node_mask_row)
node = row_to_col(node)
node_mask_col = scatter(node_mask, dim=2)
node = self.MSAColumnAttention(node, node_mask_col)
node = self.MSATransition(node)
return node
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm
from fastfold.habana.distributed import gather, scatter
from fastfold.habana.fastnn.custom_op import fused_softmax, fused_softmax_bias

from .initializer import glorot_uniform_af
from .kernel import bias_sigmod_ele
CHUNK_SIZE = None
DEBUG = False
def set_chunk_size(chunk_size):
global CHUNK_SIZE
CHUNK_SIZE = chunk_size
def get_chunk_size():
global CHUNK_SIZE
return CHUNK_SIZE
class DropoutRowwise(nn.Module):
def __init__(self, p):
super(DropoutRowwise, self).__init__()
self.p = p
self.dropout = nn.Dropout(p=p)
def forward(self, x):
dropout_mask = torch.ones_like(x[:, 0:1, :, :])
dropout_mask = self.dropout(dropout_mask)
return dropout_mask * x
class DropoutColumnwise(nn.Module):
def __init__(self, p):
super(DropoutColumnwise, self).__init__()
self.p = p
self.dropout = nn.Dropout(p=p)
def forward(self, x):
dropout_mask = torch.ones_like(x[:, :, 0:1, :])
dropout_mask = self.dropout(dropout_mask)
return dropout_mask * x
class Transition(nn.Module):
def __init__(self, d, n=4):
super(Transition, self).__init__()
self.norm = LayerNorm(d)
self.linear1 = Linear(d, n * d, initializer='relu')
self.linear2 = Linear(n * d, d, initializer='zeros')
def forward(self, src):
x = self.norm(src)
x = self.linear2(F.relu(self.linear1(x)))
return src + x
class OutProductMean(nn.Module):
def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
super(OutProductMean, self).__init__()
self.layernormM = LayerNorm(n_feat)
self.linear_a = Linear(n_feat, n_feat_proj)
self.linear_b = Linear(n_feat, n_feat_proj)
self.o_linear = Linear(n_feat_proj * n_feat_proj,
n_feat_out,
initializer='zero',
use_bias=True)
def forward(self, M, M_mask, Z_raw):
Z = torch.empty_like(Z_raw)
M = self.layernormM(M)
left_act = self.linear_a(M)
right_act = self.linear_b(M)
right_act_all = gather(right_act, dim=2)
M_mask = M_mask.unsqueeze(-1)
M_mask_col = scatter(M_mask, dim=2)
left_act = M_mask_col * left_act
right_act_all = M_mask * right_act_all
norm = torch.einsum('...ab,...ad->...bd',
M_mask_col.squeeze(-1).squeeze(0),
M_mask.squeeze(-1).squeeze(0)).unsqueeze(-1).unsqueeze(0)
        para_dim = left_act.shape[2]
        chunk_size = CHUNK_SIZE
        if chunk_size is None:
            chunk_size = para_dim

        # right_act_all is loop-invariant: flatten its projection dims once
        # (reshaping it inside the loop would fail after the first chunk).
        right_shape = right_act_all.shape
        right_act_flat = right_act_all.reshape(right_shape[0], right_shape[1],
                                               right_shape[2] * right_shape[3])

        out = []
        for ax in range(0, para_dim, chunk_size):
            left_act_part = left_act[:, :, ax:ax + chunk_size, :]
            # Equivalent to torch.einsum('sid,sje->ijde', ...) followed by
            # rearrange('i j d e -> i j (d e)'), expressed as one matmul.
            left_shape = left_act_part.shape
            left_act_part = left_act_part.reshape(left_shape[0], left_shape[1],
                                                  left_shape[2] * left_shape[3])
            O = torch.matmul(left_act_part.squeeze(0).transpose(1, 0), right_act_flat.squeeze(0))
            O = O.reshape(left_shape[2], left_shape[3], right_shape[2], right_shape[3]).transpose(-2, -3)
            O = O.reshape(O.shape[0], O.shape[1], O.shape[2] * O.shape[3])
            O = O.unsqueeze(0)
            out.append(self.o_linear(O))

        Z = torch.cat(out, dim=1)
Z /= (1e-3 + norm)
return Z + Z_raw
class Linear(nn.Linear):
"""
A Linear layer with built-in nonstandard initializations. Called just
like torch.nn.Linear.
Implements the initializers in 1.11.4, plus some additional ones found
in the code.
"""
def __init__(
self,
feature_in: int,
feature_out: int,
initializer: str = 'linear',
use_bias: bool = True,
bias_init: float = 0.,
):
super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)
self.use_bias = use_bias
if initializer == 'linear':
glorot_uniform_af(self.weight, gain=1.0)
elif initializer == 'relu':
glorot_uniform_af(self.weight, gain=2.0)
elif initializer == 'zeros':
nn.init.zeros_(self.weight)
if self.use_bias:
with torch.no_grad():
self.bias.fill_(bias_init)
class SelfAttention(nn.Module):
"""
Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
"""
def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
super(SelfAttention, self).__init__()
self.qkv_dim = qkv_dim
self.c = c
self.n_head = n_head
self.out_dim = out_dim
self.gating = gating
self.last_bias_fuse = last_bias_fuse
self.scaling = self.c**(-0.5)
self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear', use_bias=False)
# self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
# self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
# self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
if gating:
self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)
self.o_linear = Linear(n_head * c,
out_dim,
initializer='zero',
use_bias=(not last_bias_fuse))
def forward(self, in_data, mask, nonbatched_bias=None):
"""
:param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
:param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
:param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
"""
para_dim = in_data.shape[1]
        chunk_size = CHUNK_SIZE
        if chunk_size is None:
            chunk_size = para_dim
output = []
for ax in range(0, para_dim, chunk_size):
in_data_part = in_data[:, ax:ax + chunk_size, :, :]
mask_part = mask[:, ax:ax + chunk_size, :]
qkv = self.to_qkv(in_data_part).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
# q = self.to_q(in_data_part)
# k = self.to_k(in_data_part)
# v = self.to_v(in_data_part)
# q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
# [q, k, v])
q = q * self.scaling
logits = torch.matmul(q, k.transpose(-1, -2))
# logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
# if nonbatched_bias is not None:
# logits += nonbatched_bias.unsqueeze(1)
# weights = torch.softmax(logits, dim=-1)
mask00 = (1e9 * (mask_part - 1))[..., :, None, None, :]
if nonbatched_bias is not None:
weights = fused_softmax_bias(logits, mask00, nonbatched_bias.unsqueeze(1), -1)
else:
weights = fused_softmax(logits, mask00, -1)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
if self.gating:
gate_values = self.gating_linear(in_data_part)
weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
output.append(self.o_linear(weighted_avg))
output = torch.cat(output, dim=1)
return output
class GlobalAttention(nn.Module):
"""
Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
"""
def __init__(self, qkv_dim, c, n_head, out_dim):
super(GlobalAttention, self).__init__()
self.qkv_dim = qkv_dim
self.c = c
self.n_head = n_head
self.out_dim = out_dim
self.scaling = self.c**(-0.5)
self.eps = 1e-10
self.inf = 1e9
self.to_q = Linear(qkv_dim, c * self.n_head, use_bias=False)
self.to_kv = Linear(qkv_dim, 2 * c, initializer="linear", use_bias=False)
self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
self.gating_linear = Linear(qkv_dim, n_head * c, initializer="zero", use_bias=False)
self.o_linear = Linear(n_head * c, out_dim, initializer="zero")
def forward(self, m, mask):
para_dim = m.shape[1]
        chunk_size = CHUNK_SIZE
        if chunk_size is None:
            chunk_size = para_dim
output = []
for ax in range(0, para_dim, chunk_size):
m_part = m[:, ax:ax + chunk_size, :, :]
mask_part = mask[:, ax:ax + chunk_size, :]
q = torch.sum(m_part * mask_part.unsqueeze(-1),
dim=-2) / (torch.sum(mask_part, dim=-1)[..., None] + self.eps)
q = self.to_q(q)
q = q.view(q.shape[:-1] + (self.n_head, -1))
k, v = self.to_kv(m_part).chunk(2, dim=-1)
logits = torch.matmul(q, k.transpose(-1, -2))
# logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
weights = torch.softmax(logits, dim=-1)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
gate_values = self.gating_linear(m_part)
weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias,
weighted_avg.unsqueeze(-2))
output.append(self.o_linear(weighted_avg))
m = torch.cat(output, dim=1)
return m
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm
from fastfold.habana.distributed import col_to_row, gather, row_to_col, scatter
from .kernel import bias_dropout_add, bias_ele_dropout_residual
from .ops import Linear, SelfAttention, Transition
def permute_final_dims(tensor, inds):
zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
class TriangleMultiplicationOutgoing(nn.Module):
def __init__(self, d_pair, p_drop, c=128):
super(TriangleMultiplicationOutgoing, self).__init__()
self.d_pair = d_pair
self.c = c
self.layernorm1 = LayerNorm(d_pair)
self.left_projection = Linear(d_pair, c)
self.right_projection = Linear(d_pair, c)
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
self.layernorm2 = LayerNorm(c)
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
self.p_drop = p_drop
def forward(self, Z_raw, Z_mask):
Z = self.layernorm1(Z_raw)
left_proj_act = self.left_projection(Z)
right_proj_act = self.right_projection(Z)
left_proj_act = Z_mask.unsqueeze(-1) * left_proj_act
right_proj_act = Z_mask.unsqueeze(-1) * right_proj_act
left_proj_act *= torch.sigmoid(self.left_gate(Z))
right_proj_act *= torch.sigmoid(self.right_gate(Z))
right_proj_act = gather(right_proj_act.contiguous(), dim=1)
g = torch.sigmoid(self.output_gate(Z))
p = torch.matmul(
permute_final_dims(left_proj_act, (2, 0, 1)),
permute_final_dims(right_proj_act, (2, 1, 0)),
)
ab = permute_final_dims(p, (1, 2, 0))
# ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
ab = self.output_projection(self.layernorm2(ab))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class TriangleMultiplicationIncoming(nn.Module):
def __init__(self, d_pair, p_drop, c=128):
super(TriangleMultiplicationIncoming, self).__init__()
self.d_pair = d_pair
self.c = c
self.layernorm1 = LayerNorm(d_pair)
self.left_projection = Linear(d_pair, c)
self.right_projection = Linear(d_pair, c)
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
self.layernorm2 = LayerNorm(c)
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
self.p_drop = p_drop
def forward(self, Z_raw, Z_mask):
Z = self.layernorm1(Z_raw)
left_proj_act = self.left_projection(Z)
right_proj_act = self.right_projection(Z)
left_proj_act = Z_mask.unsqueeze(-1) * left_proj_act
right_proj_act = Z_mask.unsqueeze(-1) * right_proj_act
left_proj_act *= torch.sigmoid(self.left_gate(Z))
right_proj_act *= torch.sigmoid(self.right_gate(Z))
left_proj_act = gather(left_proj_act.contiguous(), dim=2)
g = torch.sigmoid(self.output_gate(Z))
p = torch.matmul(
permute_final_dims(left_proj_act, (2, 1, 0)),
permute_final_dims(right_proj_act, (2, 0, 1)),
)
ab = permute_final_dims(p, (1, 2, 0))
# ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
ab = self.output_projection(self.layernorm2(ab))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class TriangleAttentionStartingNode(nn.Module):
def __init__(self, d_pair, p_drop, c=32, n_head=4):
super(TriangleAttentionStartingNode, self).__init__()
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernorm1 = LayerNorm(d_pair)
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
std=1.0 / math.sqrt(d_pair))
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
self.attention = SelfAttention(qkv_dim=d_pair,
c=c,
n_head=n_head,
out_dim=d_pair,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
def forward(self, Z_raw, Z_mask):
Z = self.layernorm1(Z_raw)
b = F.linear(Z, self.linear_b_weights)
b = gather(b, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
Z = self.attention(Z, Z_mask, b)
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class TriangleAttentionEndingNode(nn.Module):
def __init__(self, d_pair, p_drop, c=32, n_head=4):
super(TriangleAttentionEndingNode, self).__init__()
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernorm1 = LayerNorm(d_pair)
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
std=1.0 / math.sqrt(d_pair))
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
self.attention = SelfAttention(qkv_dim=d_pair,
c=c,
n_head=n_head,
out_dim=d_pair,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
def forward(self, Z_raw, Z_mask):
Z = Z_raw.transpose(-2, -3)
Z_mask = Z_mask.transpose(-1, -2)
Z = self.layernorm1(Z)
b = F.linear(Z, self.linear_b_weights)
b = gather(b, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
Z = self.attention(Z, Z_mask, b)
Z = Z.transpose(-2, -3)
dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class PairStack(nn.Module):
def __init__(self, d_pair, p_drop=0.25):
super(PairStack, self).__init__()
self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
self.PairTransition = Transition(d=d_pair)
def forward(self, pair, pair_mask):
pair_mask_row = scatter(pair_mask, dim=1)
pair_mask_col = scatter(pair_mask, dim=2)
pair = self.TriangleMultiplicationOutgoing(pair, pair_mask_row)
pair = row_to_col(pair)
pair = self.TriangleMultiplicationIncoming(pair, pair_mask_col)
pair = col_to_row(pair)
pair = self.TriangleAttentionStartingNode(pair, pair_mask_row)
pair = row_to_col(pair)
pair = self.TriangleAttentionEndingNode(pair, pair_mask_col)
pair = self.PairTransition(pair)
pair = col_to_row(pair)
return pair