"src/lib/vscode:/vscode.git/clone" did not exist on "c49491e516aaa6023a21ae320c676048a32743cd"
Commit d5f3875c authored by Shenggan

init commit

parent 62ed1c4a
@@ -127,3 +127,10 @@ dmypy.json
# Pyre type checker
.pyre/
# vscode
.vscode/
# setup
dist/
build/
![](/assets/fold.jpg)
# FastFold
![](https://img.shields.io/github/v/release/hpcaitech/FastFold)
[![GitHub license](https://img.shields.io/github/license/hpcaitech/FastFold.svg)](https://github.com/hpcaitech/FastFold/blob/master/LICENSE)
![](https://img.shields.io/badge/Made%20with-ColossalAI-blueviolet?style=flat)
Optimizing Protein Structure Prediction Model Training and Inference on GPU Clusters
FastFold provides a **high-performance implementation of Evoformer** with the following characteristics.
1. Excellent kernel performance on GPU platforms
2. Support for Dynamic Axial Parallelism (DAP)
    * Break the memory limit of a single GPU and reduce the overall training time
    * Distributed inference can significantly accelerate prediction and makes inference on extremely long sequences possible
3. Ease of use
    * Replace a few lines of code and you can use FastFold in your project (see the sketch below)
    * You don't need to worry about how the parallel part is implemented
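For example, here is a minimal sketch of calling the FastFold Evoformer directly. It mirrors the shapes and setup of the benchmark script in `./benchmark` (the shapes are illustrative), and it should be launched with `torchrun` so that the `torch.distributed` environment variables are set:
```python
import torch
from fastfold.distributed import init_shadowcore
from fastfold.model import Evoformer

# DAP is built on torch.distributed; a DAP size of 1 disables model parallelism.
torch.distributed.init_process_group(backend='nccl', init_method='env://')
init_shadowcore(tensor_model_parallel_size_=1)

evoformer = Evoformer(d_node=256, d_pair=128).cuda().to(dtype=torch.bfloat16)

# MSA representation: (batch, msa_depth, residues, c_m)
node = torch.randn(1, 128, 256, 256, dtype=torch.bfloat16, device='cuda')
# Pair representation: (batch, residues, residues, c_z)
pair = torch.randn(1, 256, 256, 128, dtype=torch.bfloat16, device='cuda')
node_mask = torch.ones(1, 128, 256, dtype=torch.bfloat16, device='cuda')
pair_mask = torch.ones(1, 256, 256, dtype=torch.bfloat16, device='cuda')

node, pair = evoformer(node, pair, node_mask, pair_mask)
```
With a DAP size greater than 1, the MSA and pair tensors are partitioned across the participating GPUs and the communication is handled inside FastFold.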
## Installation
To install from source you will need Python 3.8 or later and [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.1 or above.
We highly recommend creating an Anaconda or Miniconda environment and installing PyTorch with conda:
```shell
conda create -n fastfold python=3.8
conda activate fastfold
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
```
You can get the FastFold source and install it with setuptools:
```shell
git clone https://github.com/hpcaitech/FastFold
cd FastFold
python setup.py install --cuda_ext
```
## Performance Benchmark
We have included a performance benchmark script in `./benchmark`. You can use it to measure Evoformer performance under different settings.
```shell
cd ./benchmark
torchrun --nproc_per_node=1 perf.py --msa-length 128 --res-length 256
```
If you want to compare against [OpenFold](https://github.com/aqlaboratory/openfold), install OpenFold first and run the benchmark with the `--openfold` option:
```shell
torchrun --nproc_per_node=1 perf.py --msa-length 128 --res-length 256 --openfold
```
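To benchmark with Dynamic Axial Parallelism, you can (for example, assuming two GPUs are available) launch one process per GPU and pass a matching `--dap-size`:
```shell
torchrun --nproc_per_node=2 perf.py --msa-length 128 --res-length 256 --dap-size 2
```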
## Cite us
Please cite this paper if you use FastFold in your research publications.
```
```
import argparse
import os
import torch
import torch.nn as nn
from fastfold.distributed import init_shadowcore
from fastfold.model import Evoformer
def main():
    parser = argparse.ArgumentParser(description='Evoformer Standalone Perf Benchmark')
parser.add_argument("--dap-size", default=1, type=int)
parser.add_argument('--batch-size', default=1, type=int, help='batch size')
    parser.add_argument('--msa-length', default=132, type=int, help='Number of sequences in the MSA')
    parser.add_argument('--res-length',
                        default=256,
                        type=int,
                        help='Number of residues (sequence length)')
parser.add_argument('--trials', default=50, type=int, help='Number of Trials to Execute')
parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
parser.add_argument('--layers',
default=12,
type=int,
help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
parser.add_argument('--cm', default=256, type=int, help='MSA hidden dimension')
parser.add_argument('--cz', default=128, type=int, help='Pair hidden dimension')
parser.add_argument('--heads', default=8, type=int, help='Number of Multihead Attention heads')
    parser.add_argument('--openfold',
                        action='store_true',
                        help='Benchmark the OpenFold EvoformerBlock instead of FastFold.')
    parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
    parser.add_argument('--prof',
                        action='store_true',
                        help='Profile with torch.profiler and write a TensorBoard trace to ./log/fastfold.')
    args = parser.parse_args()
    # The rest of the script refers to the DAP size by this name.
    args.tensor_model_parallel_size = args.dap_size
    args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
args.local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.global_rank = torch.distributed.get_rank()
print(
'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.global_rank, args.world_size))
init_shadowcore(args.tensor_model_parallel_size)
precision = torch.bfloat16
if args.tensor_model_parallel_size > 1:
# (PyTorch issue) Currently All2All communication does not support the Bfloat16 datatype in PyTorch
precision = torch.float16
if not torch.cuda.is_available():
raise NotImplementedError('Running on CPU is not supported')
torch.manual_seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(42)
if args.openfold:
from openfold.model.evoformer import EvoformerBlock
class OpenFoldEvoformer(nn.Module):
def __init__(self, d_node, d_pair):
super(OpenFoldEvoformer, self).__init__()
self.d_node = d_node
self.d_pair = d_pair
self.c_hidden_msa_att = int(d_node / 8)
self.c_hidden_pair_att = int(d_pair / 8)
self.EvoformerBlock = EvoformerBlock(c_m=d_node,
c_z=d_pair,
c_hidden_msa_att=self.c_hidden_msa_att,
c_hidden_opm=self.c_hidden_msa_att,
c_hidden_mul=self.d_pair,
c_hidden_pair_att=self.c_hidden_pair_att,
no_heads_msa=8,
no_heads_pair=4,
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.25,
inf=1e9,
eps=1e-10)
def forward(self, node, pair, node_mask, pair_mask):
node, pair = self.EvoformerBlock(node, pair, node_mask, pair_mask)
return node, pair
attn_layers = []
for idx in range(0, args.layers):
if args.openfold:
attn_layers.append(OpenFoldEvoformer(d_node=args.cm, d_pair=args.cz))
else:
attn_layers.append(Evoformer(d_node=args.cm, d_pair=args.cz))
attn_layers[idx].cuda()
attn_layers[idx].to(dtype=precision)
start_evt_fwd = []
start_evt_bwd = []
stop_evt_bwd = []
for recorded_trial in range(0, args.trials):
start_evt_fwd.append(torch.cuda.Event(enable_timing=True))
start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))
inputs_node = torch.randn(args.batch_size,
args.msa_length // args.tensor_model_parallel_size,
args.res_length,
args.cm,
dtype=precision,
device=torch.device("cuda")).requires_grad_(True)
inputs_pair = torch.randn(args.batch_size,
args.res_length // args.tensor_model_parallel_size,
args.res_length,
args.cz,
dtype=precision,
device=torch.device("cuda")).requires_grad_(True)
node_mask = torch.ones((args.batch_size, args.msa_length, args.res_length),
dtype=precision,
device=torch.device("cuda")).requires_grad_(False)
pair_mask = torch.ones((args.batch_size, args.res_length, args.res_length),
dtype=precision,
device=torch.device("cuda")).requires_grad_(False)
grads_node = torch.randn_like(inputs_pair)
if args.prof:
prof = torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1,
warmup=args.warmup_trials,
active=args.trials,
repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/fastfold'),
profile_memory=False,
record_shapes=False,
with_stack=False)
prof.start()
for trial in range(0, args.trials + args.warmup_trials):
layer_inputs = inputs_node, inputs_pair
evt_idx = trial - args.warmup_trials
torch.distributed.barrier()
torch.cuda.synchronize()
if evt_idx >= 0:
start_evt_fwd[evt_idx].record()
for lyr_idx in range(0, args.layers):
layer_inputs = attn_layers[lyr_idx].forward(*layer_inputs, node_mask, pair_mask)
torch.cuda.synchronize()
if evt_idx >= 0:
start_evt_bwd[evt_idx].record()
if not args.fwd:
layer_inputs[1].backward(grads_node)
if evt_idx >= 0:
stop_evt_bwd[evt_idx].record()
if args.prof:
prof.step()
if args.prof:
prof.stop()
torch.distributed.barrier()
torch.cuda.synchronize()
elapsed_time_fwd = 0.0
elapsed_time_bwd = 0.0
for evt_idx in range(0, args.trials):
elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])
elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])
print("[ MSA Attn ] Input: {:4d}, {:4d}, {:4d}, ({:4d} {:4d}) Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
args.batch_size, args.msa_length, args.res_length, \
args.cm, args.cz, \
elapsed_time_fwd / ( args.trials * args.layers ), \
elapsed_time_bwd / ( args.trials * args.layers )))
if __name__ == '__main__':
main()
from .core import (init_shadowcore, shadowcore_is_initialized, get_tensor_model_parallel_group,
get_data_parallel_group, get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank, get_data_parallel_world_size,
get_data_parallel_rank, get_tensor_model_parallel_src_rank)
from .comm import (_reduce, _split, _gather, copy, scatter, reduce, gather, col_to_row, row_to_col)
__all__ = [
'init_shadowcore', 'shadowcore_is_initialized', 'get_tensor_model_parallel_group',
'get_data_parallel_group', 'get_tensor_model_parallel_world_size',
'get_tensor_model_parallel_rank', 'get_data_parallel_world_size', 'get_data_parallel_rank',
'get_tensor_model_parallel_src_rank', '_reduce', '_split', '_gather', 'copy', 'scatter',
'reduce', 'gather', 'col_to_row', 'row_to_col'
]
from typing import Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from .core import (get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
                   get_tensor_model_parallel_world_size)
from .core import ensure_divisibility
def divide(numerator, denominator):
ensure_divisibility(numerator, denominator)
return numerator // denominator
def _reduce(tensor: Tensor) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor
dist.all_reduce(tensor,
op=dist.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
async_op=False)
return tensor
def _split(tensor: Tensor, dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor
split_size = divide(tensor.shape[dim], get_tensor_model_parallel_world_size())
tensor_list = torch.split(tensor, split_size, dim=dim)
    output = tensor_list[get_tensor_model_parallel_rank()].contiguous()
return output
def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor
if dim == 1:
output_shape = list(tensor.shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
dist.all_gather(list(tensor_list), tensor, group=get_tensor_model_parallel_group(), async_op=False)
else:
tensor_list = [torch.ones_like(tensor) for _ in range(get_tensor_model_parallel_world_size())]
dist.all_gather(tensor_list, tensor, group=get_tensor_model_parallel_group(), async_op=False)
output = torch.cat(tensor_list, dim=dim)
return output
def copy(input: Tensor) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Copy.apply(input)
return input
class Copy(torch.autograd.Function):
@staticmethod
def forward(ctx: "Copy", input: Tensor) -> Tensor:
return input
@staticmethod
def backward(ctx: "Copy", grad_output: Tensor) -> Tensor:
return _reduce(grad_output)
def scatter(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Scatter.apply(input, dim)
else:
input = _split(input, dim=dim)
return input
class Scatter(torch.autograd.Function):
@staticmethod
def forward(ctx: "Scatter", input: Tensor, dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([dim]))
return _split(input, dim=dim)
@staticmethod
def backward(ctx: "Scatter", grad_output: Tensor) -> Tuple[Tensor]:
dim, = ctx.saved_tensors
return _gather(grad_output, dim=int(dim)), None
def reduce(input: Tensor) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Reduce.apply(input)
else:
input = _reduce(input)
return input
class Reduce(torch.autograd.Function):
@staticmethod
def forward(ctx: "Reduce", input: Tensor) -> Tensor:
return _reduce(input)
@staticmethod
def backward(ctx: "Reduce", grad_output: Tensor) -> Tensor:
return grad_output
def gather(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Gather.apply(input, dim)
else:
input = _gather(input, dim=dim)
return input
class Gather(torch.autograd.Function):
@staticmethod
def forward(ctx: "Gather", input: Tensor, dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([dim]))
return _gather(input, dim=dim)
@staticmethod
def backward(ctx: "Gather", grad_output: Tensor) -> Tuple[Tensor]:
dim, = ctx.saved_tensors
return _split(grad_output, dim=int(dim)), None
def _all_to_all(tensor: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor
split_size = divide(tensor.shape[in_dim], get_tensor_model_parallel_world_size())
input_tensor_list = torch.split(tensor, split_size, dim=in_dim)
input_tensor_list = [tensor_.contiguous() for tensor_ in input_tensor_list]
if out_dim == 1:
output_shape = list(input_tensor_list[0].shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
output_tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
dist.all_to_all(list(output_tensor_list),
input_tensor_list,
group=get_tensor_model_parallel_group(),
async_op=False)
else:
output_tensor_list = [torch.ones_like(tensor_) for tensor_ in input_tensor_list]
dist.all_to_all(output_tensor_list,
input_tensor_list,
group=get_tensor_model_parallel_group(),
async_op=False)
output = torch.cat(output_tensor_list, dim=out_dim)
return output
def col_to_row(input_: Tensor) -> Tensor:
if torch.is_grad_enabled() and input_.requires_grad:
input_ = All_to_All.apply(input_, 1, 2)
else:
input_ = _all_to_all(input_, in_dim=1, out_dim=2)
return input_
def row_to_col(input_: Tensor) -> Tensor:
if torch.is_grad_enabled() and input_.requires_grad:
input_ = All_to_All.apply(input_, 2, 1)
else:
input_ = _all_to_all(input_, in_dim=2, out_dim=1)
return input_
class All_to_All(torch.autograd.Function):
@staticmethod
def forward(ctx: "All_to_All", input_: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([in_dim, out_dim]))
return _all_to_all(input_, in_dim=in_dim, out_dim=out_dim)
@staticmethod
def backward(ctx: "All_to_All", grad_output: Tensor) -> Tuple[Tensor]:
saved_tensors = ctx.saved_tensors[0]
return _all_to_all(grad_output, in_dim=int(saved_tensors[1]),
out_dim=int(saved_tensors[0])), None, None
from typing import Tuple
from einops import rearrange
import torch
import torch.distributed as dist
from torch import Tensor
from .core import get_tensor_model_parallel_world_size, get_tensor_model_parallel_group
from .comm import _split, divide
def _gather_async(tensor: Tensor, dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor, None
output_shape = list(tensor.shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
work = dist.all_gather(list(tensor_list),
tensor,
group=get_tensor_model_parallel_group(),
async_op=True)
return output, work
def gather_async(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input, work = GatherAsync.apply(input, dim)
else:
input, work = _gather_async(input, dim=dim)
return input, work
def gather_async_opp(output: Tensor, work, dim: int = -1) -> Tensor:
if work:
work.wait()
if dim == 2:
output = GatherAsyncOpp.apply(output)
return output
class GatherAsyncOpp(torch.autograd.Function):
@staticmethod
def forward(ctx: "GatherAsyncOpp", input: Tensor) -> Tensor:
mp_size = get_tensor_model_parallel_world_size()
output = rearrange(input, 'n (x h) w c -> n h (x w) c', x=mp_size)
return output
@staticmethod
def backward(ctx: "GatherAsyncOpp", grad_output: Tensor) -> Tuple[Tensor]:
mp_size = get_tensor_model_parallel_world_size()
n, h, w, c = grad_output.shape
return grad_output.resize_(n, h * mp_size, int(w / mp_size), c)
class GatherAsync(torch.autograd.Function):
@staticmethod
def forward(ctx: "GatherAsync", input: Tensor, dim: int = -1) -> Tensor:
ctx.dim = dim
return _gather_async(input, dim=dim)
@staticmethod
def backward(ctx: "GatherAsync", grad_output: Tensor, grad_work=None) -> Tuple[Tensor]:
if ctx.dim == 2:
mp_size = get_tensor_model_parallel_world_size()
n, h, w, c = grad_output.shape
grad_output.resize_(n, int(h / mp_size), w * mp_size, c)
return _split(grad_output, dim=ctx.dim), None
def _all_to_all_async(tensor: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor, None
split_size = divide(tensor.shape[in_dim], get_tensor_model_parallel_world_size())
input_tensor_list = torch.split(tensor, split_size, dim=in_dim)
input_tensor_list = [tensor_.contiguous() for tensor_ in input_tensor_list]
output_shape = list(input_tensor_list[0].shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
output_tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
work = dist.all_to_all(list(output_tensor_list),
input_tensor_list,
group=get_tensor_model_parallel_group(),
async_op=True)
return output, work
WORLD_WORK_ALL2ALL = None
class All_to_All_Async(torch.autograd.Function):
@staticmethod
def forward(ctx: "All_to_All_Async",
input_: Tensor,
in_dim: int = -1,
out_dim: int = -1) -> Tensor:
ctx.in_dim = in_dim
ctx.out_dim = out_dim
return _all_to_all_async(input_, in_dim=in_dim, out_dim=out_dim)
@staticmethod
def backward(ctx: "All_to_All_Async", grad_output: Tensor, grad_work=None) -> Tuple[Tensor]:
global WORLD_WORK_ALL2ALL
if WORLD_WORK_ALL2ALL:
WORLD_WORK_ALL2ALL.wait()
WORLD_WORK_ALL2ALL = None
if ctx.in_dim == 2:
mp_size = get_tensor_model_parallel_world_size()
grad_output = rearrange(grad_output, 'n (x h) w c -> n h (x w) c', x=mp_size)
return grad_output, None, None
class All_to_All_Async_Opp(torch.autograd.Function):
@staticmethod
def forward(ctx: "All_to_All_Async_Opp",
output: Tensor,
work,
in_dim: int = -1,
out_dim: int = -1) -> Tensor:
ctx.in_dim = in_dim
ctx.out_dim = out_dim
if work:
work.wait()
if out_dim == 2:
mp_size = get_tensor_model_parallel_world_size()
output = rearrange(output, 'n (x h) w c -> n h (x w) c', x=mp_size)
return output
@staticmethod
def backward(ctx: "All_to_All_Async_Opp", grad_output: Tensor) -> Tuple[Tensor]:
global WORLD_WORK_ALL2ALL
d_tensor, WORLD_WORK_ALL2ALL = _all_to_all_async(grad_output,
in_dim=ctx.out_dim,
out_dim=ctx.in_dim)
return d_tensor, None, None, None
import torch.distributed as dist
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# These values enable us to change the mpu sizes on the fly.
_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_TENSOR_MODEL_PARALLEL_RANK = None
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)
def init_shadowcore(tensor_model_parallel_size_=1):
assert dist.is_initialized()
world_size = dist.get_world_size()
rank = dist.get_rank()
# check dist config
ensure_divisibility(world_size, tensor_model_parallel_size_)
data_parallel_size_ = world_size // tensor_model_parallel_size_
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized'
for i in range(tensor_model_parallel_size_):
ranks = range(i, world_size, tensor_model_parallel_size_)
group = dist.new_group(ranks)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
'tensor model parallel group is already initialized'
# Build the model-parallel groups.
for i in range(data_parallel_size_):
ranks = range(i * tensor_model_parallel_size_, (i + 1) * tensor_model_parallel_size_)
group = dist.new_group(ranks)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
if dist.get_rank() == 0:
print('> initialize tensor model parallel with size {}'.format(tensor_model_parallel_size_))
print('> initialize data parallel with size {}'.format(data_parallel_size_))
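# Illustrative example (not executed): with a world size of 4 and
# tensor_model_parallel_size_ = 2, the loops above build data-parallel groups
# [0, 2] and [1, 3] (ranks strided by the DAP size) and tensor-model-parallel
# (DAP) groups [0, 1] and [2, 3] (contiguous blocks of ranks).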
def shadowcore_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
return False
return True
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
return _TENSOR_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
global _TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
return _TENSOR_MODEL_PARALLEL_WORLD_SIZE
return dist.get_world_size(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
global _TENSOR_MODEL_PARALLEL_RANK
if _TENSOR_MODEL_PARALLEL_RANK is not None:
return _TENSOR_MODEL_PARALLEL_RANK
return dist.get_rank(group=get_tensor_model_parallel_group())
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return dist.get_world_size(group=get_data_parallel_group())
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return dist.get_rank(group=get_data_parallel_group())
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = dist.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
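# For example, with a tensor model parallel (DAP) world size of 2, global ranks 2 and 3
# both map to source rank (3 // 2) * 2 == 2, the first rank of their DAP group.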
from .msa import MSAStack
from .ops import OutProductMean
from .triangle import PairStack
from .evoformer import Evoformer
__all__ = ['MSAStack', 'OutProductMean', 'PairStack', 'Evoformer']
import torch.nn as nn
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.model import MSAStack, OutProductMean, PairStack
class Evoformer(nn.Module):
def __init__(self, d_node, d_pair):
super(Evoformer, self).__init__()
self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
self.pair_stack = PairStack(d_pair=d_pair)
def forward(self, node, pair, node_mask, pair_mask):
node = self.msa_stack(node, pair, node_mask)
pair = pair + self.communication(node, node_mask)
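        # Kick off an asynchronous all-to-all that switches which axis of the MSA
        # tensor is partitioned across the DAP group; the pair stack below runs while
        # the communication is in flight, and All_to_All_Async_Opp waits on the handle
        # before `node` is used again.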
node, work = All_to_All_Async.apply(node, 1, 2)
pair = self.pair_stack(pair, pair_mask)
node = All_to_All_Async_Opp.apply(node, work, 1, 2)
return node, pair
import math
import numpy as np
import torch.nn as nn
def glorot_uniform_af(x, gain=1.0):
"""
initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
In PyTorch:
[feature_out, feature_in, n_head ...]
In Jax:
[... n_head, feature_in, feature_out]
However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like:
[feature_in, n_head, feature_out]
In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
"""
fan_in, fan_out = x.shape[-2:]
if len(x.shape) > 2:
receptive_field_size = np.prod(x.shape[:-2])
fan_in *= receptive_field_size
fan_out *= receptive_field_size
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
nn.init.uniform_(x, -dev, dev)
return x
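# Illustrative usage (not part of the library): AlphaFold-style weights laid out as
# [feature_in, n_head, feature_out] get uniform bounds computed from the last two
# dimensions, each scaled by the product of any leading dimensions:
#
#   w = torch.empty(64, 8, 32)      # assumes `import torch`
#   glorot_uniform_af(w, gain=1.0)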
from .jit.fused_ops import bias_dropout_add, bias_sigmod_ele, bias_ele_dropout_residual
from .cuda_native.layer_norm import MixedFusedLayerNorm as LayerNorm
from .cuda_native.softmax import softmax, scale_mask_softmax, scale_mask_bias_softmax
__all__ = [
"bias_dropout_add", "bias_sigmod_ele", "bias_ele_dropout_residual", "LayerNorm", "softmax",
"scale_mask_softmax", "scale_mask_bias_softmax"
]
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/extension.h>
#include <cassert>
#include <vector>
#include "compat.h"
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int& n1, int& n2) {
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert(input.sizes()[i + idiff] == normalized_shape[i]);
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}
void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta) {
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int& n1, int& n2) {
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape=" << normalized_shape;
throw std::runtime_error(ss.str());
}
auto input_shape = input.sizes();
auto input_ndim = input.dim();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape << ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}
compute_n1_n2(input, normalized_shape, n1, n2);
}
void check_args(at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma,
at::Tensor beta, int& n1, int& n2) {
check_args(input, normalized_shape, n1, n2);
check_args(normalized_shape, gamma, beta);
}
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, at::Tensor* input,
int n1, int n2, at::IntArrayRef normalized_shape, at::Tensor* gamma,
at::Tensor* beta, double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm_affine(at::Tensor input, at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta, double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor mean = at::empty({n1}, input.options().dtype(at::ScalarType::Float));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma, &beta,
epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar,
at::Tensor* input, int n1, int n2, at::IntArrayRef normalized_shape,
at::Tensor* gamma, at::Tensor* beta, double epsilon,
at::Tensor* grad_input, at::Tensor* grad_gamma,
at::Tensor* grad_beta);
std::vector<at::Tensor> layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean,
at::Tensor invvar, at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma, &beta,
epsilon, &grad_input, &grad_gamma, &grad_beta);
return {grad_input, grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
}
#include <cooperative_groups.h>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#include "layer_norm.cuh"
#include "type_shim.h"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, at::Tensor* input,
int n1, int n2, at::IntArrayRef normalized_shape, at::Tensor* gamma,
at::Tensor* beta, double epsilon) {
at::Tensor normalized = at::empty_like(*output);
fastfold::layer_norm::DirectLoad<at::BFloat16, float> load((at::BFloat16*)input->data_ptr(),
n2);
fastfold::layer_norm::AffineStore<float, at::BFloat16, true, true> store(
(at::BFloat16*)normalized.data_ptr(), (at::BFloat16*)output->data_ptr(), n2,
(at::BFloat16*)gamma->data_ptr(), (at::BFloat16*)beta->data_ptr());
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::layer_norm::DispatchLayerNorm<decltype(load), decltype(store), float>(
cuda_stream, load, store, n1, n2, epsilon, (float*)mean->data_ptr(),
(float*)invvar->data_ptr());
}
template <typename T>
struct SharedMemory;
template <>
struct SharedMemory<float> {
__device__ float* getPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
template <typename T, typename U, typename V>
__device__ void cuLoadWriteStridedInputs(const int i1_block, const int thr_load_row_off,
const int thr_load_col_off, const int i2_off,
const int row_stride, U* warp_buf1, U* warp_buf2,
const T* input, const V* dout, const int i1_end,
const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template <typename T, typename U, typename V>
__device__ void cuLoadAddStridedInputs(const int i1_block, const int thr_load_row_off,
const int thr_load_col_off, const int i2_off,
const int row_stride, U* warp_buf1, U* warp_buf2,
const T* input, const V* dout, const int i1_end,
const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
template <typename T, typename U, typename V>
__global__ void cuComputePartGradGammaBeta(const V* __restrict__ dout, const T* __restrict__ input,
const int n1, const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar, U epsilon,
U* part_grad_gamma, U* part_grad_beta) {
const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
const int i1_beg_plus_one = (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x + 1;
const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y +
// (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar);
for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
i1_block += blockDim.y * blockDim.y) {
cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k * blockDim.y;
int idx1 = row1 * row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template <typename U, typename V>
__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, const U* part_grad_beta,
const int part_size, const int n1, const int n2,
V* grad_gamma, V* grad_beta) {
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr =
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
sum_beta += part_grad_beta_ptr[warp_offset * n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx + nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx + nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template <typename T, typename U, typename V>
void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, at::Tensor* input, int n1,
int n2, const V* gamma, const V* beta, double epsilon, T* grad_input,
V* grad_gamma, V* grad_beta) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32, 4, 1);
const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma =
at::empty({part_size, n2}, input->options().dtype(at::ScalarType::Float));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon),
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>());
const dim3 threads3(32, 8, 1);
const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>(), part_size, n1, n2,
grad_gamma, grad_beta);
}
}
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar,
at::Tensor* input, int n1, int n2, at::IntArrayRef normalized_shape,
at::Tensor* gamma, at::Tensor* beta, double epsilon,
at::Tensor* grad_input, at::Tensor* grad_gamma,
at::Tensor* grad_beta) {
at::Tensor add_to_output = at::empty_like(*grad_input);
fastfold::layer_norm::DirectLoad<at::BFloat16, float> load_x((at::BFloat16*)input->data_ptr(),
n2);
fastfold::layer_norm::ScaleLoad<at::BFloat16, float, true> load_scaled_dy(
(at::BFloat16*)dout->data_ptr(), (at::BFloat16*)gamma->data_ptr(), n2);
fastfold::layer_norm::AddStore<float, at::BFloat16, true> store(
(at::BFloat16*)add_to_output.data_ptr(), (at::BFloat16*)grad_input->data_ptr(), n2);
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::layer_norm::DispatchLayerNormGrad<decltype(load_x), decltype(load_scaled_dy),
decltype(store), float>(
cuda_stream, load_x, load_scaled_dy, store, (float*)mean->data_ptr(),
(float*)invvar->data_ptr(), n1, n2);
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(), "cuda_layer_norm_gradient_kernel",
HostLayerNormGradient(dout->DATA_PTR<scalar_t_out>(), mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(), input, n1, n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL, epsilon,
grad_input->DATA_PTR<scalar_t_in>(),
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);)
}
#include <torch/extension.h>
at::Tensor softmax(at::Tensor input, int rows, int cols);
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor input, int rows, int cols);
at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
float scale);
at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
int rows, int cols, float scale);
at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
int rows, int cols, float scale);
at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input,
at::Tensor mask, at::Tensor bias, int rows,
int cols, float scale);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &softmax, "Softmax forward (CUDA)");
m.def("backward_affine", &softmax_gradient, "Softmax backward (CUDA)");
m.def("fused_scale_mask_softmax_forward", &fused_scale_mask_softmax_forward,
"Softmax forward (CUDA)");
m.def("fused_scale_mask_softmax_backward", &fused_scale_mask_softmax_backward,
"Softmax forward (CUDA)");
m.def("fused_scale_mask_bias_softmax_forward", &fused_scale_mask_bias_softmax_forward,
"Softmax forward (CUDA)");
m.def("fused_scale_mask_bias_softmax_backward", &fused_scale_mask_bias_softmax_backward,
"Softmax forward (CUDA)");
}