Commit d5f3875c authored by Shenggan

init commit

parent 62ed1c4a
#include <iostream>

#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"

#include "compat.h"
#include "softmax.cuh"

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

at::Tensor softmax(at::Tensor input, int rows, int cols) {
    CHECK_INPUT(input);

    at::Tensor output = at::empty_like(input);

    fastfold::softmax::DirectLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
                                                            int64_t(cols));
    fastfold::softmax::DirectStore<float, at::BFloat16> store((at::BFloat16 *)output.data_ptr(),
                                                              int64_t(cols));

    auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
    fastfold::softmax::DispatchSoftmax<decltype(load), decltype(store), float>(cuda_stream, load,
                                                                               store, rows, cols);

    return output;
}

at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor input, int rows, int cols) {
    CHECK_INPUT(input);

    at::Tensor grad_input = at::empty_like(input);

    fastfold::softmax::DirectLoad<at::BFloat16, float> load_d((at::BFloat16 *)d_output.data_ptr(),
                                                              int64_t(cols));
    fastfold::softmax::DirectLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
                                                            int64_t(cols));
    fastfold::softmax::DirectStore<float, at::BFloat16> store((at::BFloat16 *)grad_input.data_ptr(),
                                                              int64_t(cols));

    auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
    fastfold::softmax::DispatchSoftmaxGrad<decltype(load), decltype(load_d), decltype(store),
                                           float>(cuda_stream, load, load_d, store, rows, cols);

    return grad_input;
}

at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
                                            float scale) {
    CHECK_INPUT(input);
    CHECK_INPUT(mask);

    int head = input.sizes()[2];
    at::Tensor output = at::empty_like(input);

    // (const SRC *src, const SRC *mask, int64_t row_size, int64_t head, SRC scale)
    fastfold::softmax::ScaleMaskLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
                                                               (at::BFloat16 *)mask.data_ptr(),
                                                               int64_t(cols), int64_t(head), scale);
    fastfold::softmax::DirectStore<float, at::BFloat16> store((at::BFloat16 *)output.data_ptr(),
                                                              int64_t(cols));

    auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
    fastfold::softmax::DispatchSoftmax<decltype(load), decltype(store), float>(cuda_stream, load,
                                                                               store, rows, cols);

    return output;
}

at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
                                             int rows, int cols, float scale) {
    CHECK_INPUT(input);
    CHECK_INPUT(mask);

    int head = input.sizes()[2];
    at::Tensor grad_input = at::empty_like(input);

    fastfold::softmax::DirectLoad<at::BFloat16, float> load_d((at::BFloat16 *)d_output.data_ptr(),
                                                              int64_t(cols));
    fastfold::softmax::DirectLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
                                                            int64_t(cols));
    // (DST *dst, const DST *mask, int64_t row_size, int64_t head, DST scale)
    fastfold::softmax::ScaleMaskStore<float, at::BFloat16> store(
        (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), int64_t(cols),
        int64_t(head), scale);

    auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
    fastfold::softmax::DispatchSoftmaxGrad<decltype(load), decltype(load_d), decltype(store),
                                           float>(cuda_stream, load, load_d, store, rows, cols);

    return grad_input;
}

at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
                                                 int rows, int cols, float scale) {
    CHECK_INPUT(input);
    CHECK_INPUT(mask);
    CHECK_INPUT(bias);

    int head = input.sizes()[2];
    at::Tensor output = at::empty_like(input);

    // (const SRC *src, const SRC *mask, const SRC *bias, int64_t row_size, int64_t head, SRC scale)
    fastfold::softmax::ScaleMaskBiasLoad<at::BFloat16, float> load(
        (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
        (at::BFloat16 *)bias.data_ptr(), int64_t(cols), int64_t(head), scale);
    fastfold::softmax::DirectStore<float, at::BFloat16> store((at::BFloat16 *)output.data_ptr(),
                                                              int64_t(cols));

    auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
    fastfold::softmax::DispatchSoftmax<decltype(load), decltype(store), float>(cuda_stream, load,
                                                                               store, rows, cols);

    return output;
}

at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input,
                                                  at::Tensor mask, at::Tensor bias, int rows,
                                                  int cols, float scale) {
    CHECK_INPUT(input);
    CHECK_INPUT(mask);

    int head = input.sizes()[2];
    // CHECK_INPUT(bias);

    at::Tensor grad_input = at::empty_like(input);

    fastfold::softmax::DirectLoad<at::BFloat16, float> load_d((at::BFloat16 *)d_output.data_ptr(),
                                                              int64_t(cols));
    fastfold::softmax::DirectLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
                                                            int64_t(cols));
    // (DST *dst, const DST *mask, int64_t row_size, int64_t head, DST scale)
    fastfold::softmax::ScaleMaskStore<float, at::BFloat16> store(
        (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), int64_t(cols),
        int64_t(head), scale);

    auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
    fastfold::softmax::DispatchSoftmaxGrad<decltype(load), decltype(load_d), decltype(store),
                                           float>(cuda_stream, load, load_d, store, rows, cols);

    return grad_input;
}
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: { \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T *x, T val, int lanes = 1,
                        bool share_result = false)  // lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64) {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
        if (tid < i) x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    T final;

    if (tid < 32) {
        if (blockSize >= 64)
            final = x[tid] + x[tid + 32];
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result) {
        if (tid < lanes) x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}

template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
                               bool share_result = false)  // lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64) {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
        if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
        __syncthreads();
    }

    T final;

    if (tid < 32) {
        if (blockSize >= 64)
            final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
    }

    if (share_result) {
        if (tid < lanes) x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
import importlib
import numbers

import torch
from torch.nn import init
from torch.nn.parameter import Parameter

global fastfold_layer_norm_cuda
fastfold_layer_norm_cuda = None


class FusedLayerNormAffineFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fastfold_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias \
            = fastfold_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar,
                input_, ctx.normalized_shape,
                weight_, bias_, ctx.eps)
        return grad_input, grad_weight, grad_bias, None, None


class MixedFusedLayerNorm(torch.nn.Module):

    def __init__(self, normalized_shape, eps=1e-5):
        super(MixedFusedLayerNorm, self).__init__()

        global fastfold_layer_norm_cuda
        if fastfold_layer_norm_cuda is None:
            try:
                fastfold_layer_norm_cuda = importlib.import_module("fastfold_layer_norm_cuda")
            except ImportError:
                raise RuntimeError('MixedFusedLayerNorm requires cuda extensions')

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.weight = Parameter(torch.Tensor(*normalized_shape))
        self.bias = Parameter(torch.Tensor(*normalized_shape))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
                                                  self.normalized_shape, self.eps)
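

# --- Usage sketch (editor's addition, not part of the original commit) ---
# A minimal check of MixedFusedLayerNorm against torch.nn.LayerNorm, assuming the
# fastfold_layer_norm_cuda extension has been built and a CUDA device is available.
# The hidden size (256) and input shape are arbitrary illustration values.
if __name__ == "__main__":
    hidden = 256
    x = torch.randn(4, 128, hidden, device="cuda", requires_grad=True)

    fused_ln = MixedFusedLayerNorm(hidden).cuda()
    ref_ln = torch.nn.LayerNorm(hidden, eps=fused_ln.eps).cuda()

    # Both layers start from weight=1, bias=0, so outputs should agree closely.
    y_fused = fused_ln(x)
    y_ref = ref_ln(x)
    print("max abs diff:", (y_fused - y_ref).abs().max().item())

    # Gradients flow through FusedLayerNormAffineFunction.backward.
    y_fused.sum().backward()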
import importlib
from functools import reduce
from operator import mul

import torch

fastfold_softmax_cuda = importlib.import_module("fastfold_softmax_cuda")


class SoftmaxAffineFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        input_ = input.contiguous()
        ctx.cols = input_.shape[-1]
        ctx.rows = reduce(mul, input.shape[:-1])
        output = fastfold_softmax_cuda.forward_affine(input_, ctx.rows, ctx.cols)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output = ctx.saved_tensors[0]
        grad_input = None
        grad_input = fastfold_softmax_cuda.backward_affine(grad_output.contiguous(), output,
                                                           ctx.rows, ctx.cols)
        return grad_input


class FusedScaleMaskSoftmaxFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, mask, scale):
        input_ = input.contiguous()
        mask_ = mask.contiguous()
        ctx.cols = input_.shape[-1]
        ctx.rows = reduce(mul, input.shape[:-1])
        output = fastfold_softmax_cuda.fused_scale_mask_softmax_forward(
            input_, mask_, ctx.rows, ctx.cols, scale)
        ctx.save_for_backward(output, mask_)
        ctx.scale = scale
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, mask_ = ctx.saved_tensors
        grad_input = None
        grad_input = fastfold_softmax_cuda.fused_scale_mask_softmax_backward(
            grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols, ctx.scale)
        return grad_input.contiguous(), None, None


class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, mask, bias, scale):
        input_ = input.contiguous()
        mask_ = mask.contiguous()
        bias_ = bias.contiguous()
        ctx.cols = input_.shape[-1]
        ctx.rows = reduce(mul, input.shape[:-1])
        output = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_forward(
            input_, mask_, bias_, ctx.rows, ctx.cols, scale)
        ctx.save_for_backward(output, mask_, bias_)
        ctx.scale = scale
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, mask_, bias_ = ctx.saved_tensors
        grad_input = None
        grad_input = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_backward(
            grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols, ctx.scale)
        grad_input = grad_input.contiguous()
        grad_bias = torch.sum(grad_input, dim=1, keepdim=True)
        # Gradients must follow the forward argument order (input, mask, bias, scale),
        # so the bias gradient goes in the third slot and the mask gets None.
        return grad_input, None, grad_bias, None

softmax = SoftmaxAffineFunction.apply
scale_mask_softmax = FusedScaleMaskSoftmaxFunction.apply
scale_mask_bias_softmax = FusedScaleMaskBiasSoftmaxFunction.apply
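

# --- Usage sketch (editor's addition, not part of the original commit) ---
# A minimal sanity check of the plain fused softmax against torch.softmax, assuming the
# fastfold_softmax_cuda extension is built; the C++ bindings above cast data pointers to
# at::BFloat16, so a contiguous bfloat16 CUDA tensor is used here. The scale_mask_* variants
# additionally take the mask/bias laid out as in the attention modules of this commit.
if __name__ == "__main__":
    x = torch.randn(2, 8, 64, 64, device="cuda", dtype=torch.bfloat16)
    y = softmax(x)                       # fused kernel, softmax over the last dimension
    y_ref = torch.softmax(x, dim=-1)     # PyTorch reference
    print("max abs diff:", (y.float() - y_ref.float()).abs().max().item())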
from .options import _set_jit_fusion_options
_set_jit_fusion_options()
import torch
import torch.nn.functional as F


@torch.jit.script
def bias_sigmod_ele(y, bias, z):
    return torch.sigmoid(y + bias) * z


@torch.jit.script
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                     residual: torch.Tensor, prob: float) -> torch.Tensor:
    out = (x + bias) * F.dropout(dropmask, p=prob, training=True)
    out = residual + out
    return out


@torch.jit.script
def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
                              dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
                              prob: float) -> torch.Tensor:
    return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
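

# --- Usage sketch (editor's addition, not part of the original commit) ---
# bias_dropout_add fuses "bias add, dropout-mask scaling, residual add" into one
# TorchScript kernel:
#     out = residual + (x + bias) * F.dropout(dropmask, p=prob, training=True)
# The caller supplies dropmask (usually an all-ones tensor broadcast over the dropped axis),
# so the same mask is shared across rows or columns, as the modules below do. Shapes here
# are arbitrary illustration values.
if __name__ == "__main__":
    x = torch.randn(2, 8, 16, 32)
    bias = torch.zeros(32)
    residual = torch.randn(2, 8, 16, 32)
    dropmask = torch.ones(2, 1, 16, 32)  # shared across dim 1, mirroring row-wise dropout
    out = bias_dropout_add(x, bias, dropmask, residual, prob=0.15)
    print(out.shape)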
import torch

JIT_OPTIONS_SET = False


def _set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    global JIT_OPTIONS_SET
    if not JIT_OPTIONS_SET:
        # flags required to enable jit fusion kernels
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        # if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
        #     # nvfuser
        #     torch._C._jit_set_profiling_executor(True)
        #     torch._C._jit_set_profiling_mode(True)
        #     torch._C._jit_override_can_fuse_on_cpu(False)
        #     torch._C._jit_override_can_fuse_on_gpu(False)
        #     torch._C._jit_set_texpr_fuser_enabled(False)
        #     torch._C._jit_set_nvfuser_enabled(True)
        #     torch._C._debug_set_autodiff_subgraph_inlining(False)
        # else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)

        JIT_OPTIONS_SET = True
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from fastfold.model.kernel import LayerNorm
from fastfold.model.ops import Transition, SelfAttention
from fastfold.model.kernel import bias_dropout_add
from fastfold.distributed import scatter, row_to_col
from fastfold.distributed.comm_async import gather_async


class MSARowAttentionWithPairBias(nn.Module):

    def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
        super(MSARowAttentionWithPairBias, self).__init__()

        self.d_node = d_node
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernormM = LayerNorm(d_node)
        self.layernormZ = LayerNorm(d_pair)

        _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
                                              std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)

        self.attention = SelfAttention(qkv_dim=d_node,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_node,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)

    def forward(self, M_raw, Z, M_mask):
        ## Input projections
        M = self.layernormM(M_raw)
        Z = self.layernormZ(Z)
        b = F.linear(Z, self.linear_b_weights)
        b, work = gather_async(b, dim=1)
        # b = rearrange(b, 'b q k h -> b h q k')

        # padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
        M = self.attention(M, M_mask, (b, work))

        dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)

        return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)


class MSAColumnAttention(nn.Module):

    def __init__(self, d_node, c=32, n_head=8):
        super(MSAColumnAttention, self).__init__()
        self.d_node = d_node
        self.c = c
        self.n_head = n_head

        self.layernormM = LayerNorm(d_node)
        self.attention = SelfAttention(qkv_dim=d_node,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_node,
                                       gating=True)

    def forward(self, M_raw, M_mask):
        M = M_raw.transpose(-2, -3)
        M = self.layernormM(M)
        M_mask = M_mask.transpose(-1, -2)

        # padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
        M = self.attention(M, M_mask)

        M = M.transpose(-2, -3)

        return M_raw + M


class MSAStack(nn.Module):

    def __init__(self, d_node, d_pair, p_drop=0.15):
        super(MSAStack, self).__init__()

        self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
                                                                       d_pair=d_pair,
                                                                       p_drop=p_drop)

        self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
        self.MSATransition = Transition(d=d_node)

    def forward(self, node, pair, node_mask):
        # split node in row-axis
        node_mask_row = scatter(node_mask, dim=1)
        node = self.MSARowAttentionWithPairBias(node, pair, node_mask_row)
        node = row_to_col(node)

        node_mask_col = scatter(node_mask, dim=2)
        node = self.MSAColumnAttention(node, node_mask_col)

        node = self.MSATransition(node)

        return node
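

# --- Usage sketch (editor's addition, not part of the original commit) ---
# Tensor layout assumed by this block (names are illustrative):
#   node:      [batch, n_seq, n_res, d_node]   (MSA representation)
#   pair:      [batch, n_res, n_res, d_pair]   (pair representation)
#   node_mask: [batch, n_seq, n_res]
# Running the forward pass requires an initialized model-parallel group (scatter/row_to_col
# come from fastfold.distributed) and the fused CUDA kernels; constructing the module only
# needs the fastfold package with its extensions built.
if __name__ == "__main__":
    msa_stack = MSAStack(d_node=256, d_pair=128, p_drop=0.15)
    n_params = sum(p.numel() for p in msa_stack.parameters())
    print(f"MSAStack parameters: {n_params}")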
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from fastfold.model.kernel import scale_mask_softmax, scale_mask_bias_softmax
from fastfold.model.kernel import LayerNorm
from .initializer import glorot_uniform_af
from fastfold.model.kernel import bias_sigmod_ele
from fastfold.distributed import gather, scatter
from fastfold.distributed.comm_async import gather_async, gather_async_opp


class DropoutRowwise(nn.Module):

    def __init__(self, p):
        super(DropoutRowwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, 0:1, :, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class DropoutColumnwise(nn.Module):

    def __init__(self, p):
        super(DropoutColumnwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, :, 0:1, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class Transition(nn.Module):

    def __init__(self, d, n=4):
        super(Transition, self).__init__()
        self.norm = LayerNorm(d)
        self.linear1 = Linear(d, n * d, initializer='relu')
        self.linear2 = Linear(n * d, d, initializer='zeros')

    def forward(self, src):
        x = self.norm(src)
        x = self.linear2(F.relu(self.linear1(x)))
        return src + x


class OutProductMean(nn.Module):

    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
        super(OutProductMean, self).__init__()

        self.layernormM = LayerNorm(n_feat)
        self.linear_a = Linear(n_feat, n_feat_proj)
        self.linear_b = Linear(n_feat, n_feat_proj)

        self.o_linear = Linear(n_feat_proj * n_feat_proj,
                               n_feat_out,
                               initializer='zero',
                               use_bias=True)

    def forward(self, M, M_mask):
        M = self.layernormM(M)
        right_act = self.linear_b(M)
        right_act_all, work = gather_async(right_act, dim=2)
        # right_act_all = gather(right_act, dim=2)

        left_act = self.linear_a(M)
        M_mask = M_mask.unsqueeze(-1)
        M_mask_col = scatter(M_mask, dim=2)
        left_act = M_mask_col * left_act
        norm = torch.einsum('bsid,bsjd->bijd', M_mask_col, M_mask)

        right_act_all = gather_async_opp(right_act_all, work, dim=2)
        right_act_all = M_mask * right_act_all

        O = torch.einsum('bsid,bsje->bijde', left_act, right_act_all)
        O = rearrange(O, 'b i j d e -> b i j (d e)')

        Z = self.o_linear(O)
        Z /= (1e-3 + norm)

        return Z
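

# --- Reference sketch (editor's addition, not part of the original commit) ---
# Without model parallelism, OutProductMean reduces to an outer product over the sequence
# dimension (up to the learned projections and the mask normalization):
#   O[b, i, j, d, e] = sum_s left[b, s, i, d] * right[b, s, j, e]
# followed by flattening (d, e) and a linear projection. A single-device reference of just
# the einsum, with arbitrary illustration shapes:
if __name__ == "__main__":
    b, s, r, d = 1, 8, 16, 4
    left = torch.randn(b, s, r, d)
    right = torch.randn(b, s, r, d)
    O_ref = torch.einsum('bsid,bsje->bijde', left, right)
    print(O_ref.shape)  # (b, r, r, d, d)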


class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.

    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.
    """

    def __init__(
        self,
        feature_in: int,
        feature_out: int,
        initializer: str = 'linear',
        use_bias: bool = True,
        bias_init: float = 0.,
    ):
        super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)

        self.use_bias = use_bias
        if initializer == 'linear':
            glorot_uniform_af(self.weight, gain=1.0)
        elif initializer == 'relu':
            glorot_uniform_af(self.weight, gain=2.0)
        elif initializer in ('zeros', 'zero'):
            # Accept both spellings: call sites in this commit use 'zeros' and 'zero'.
            nn.init.zeros_(self.weight)
        if self.use_bias:
            with torch.no_grad():
                self.bias.fill_(bias_init)


class SelfAttention(nn.Module):
    """
    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors.
    """

    def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
        super(SelfAttention, self).__init__()
        self.qkv_dim = qkv_dim
        self.c = c
        self.n_head = n_head
        self.out_dim = out_dim
        self.gating = gating
        self.last_bias_fuse = last_bias_fuse

        self.scaling = self.c**(-0.5)

        self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear', use_bias=False)
        # self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        # self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        # self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)

        if gating:
            self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
            self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)

        self.o_linear = Linear(n_head * c,
                               out_dim,
                               initializer='zero',
                               use_bias=(not last_bias_fuse))

    def forward(self, in_data, mask, nonbatched_bias=None):
        """
        :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
        :param mask: padding mask, broadcastable to the attention logits
            [batch_size1, batch_size2, n_head, len_q, len_kv]
        :param nonbatched_bias: None, or a (tensor, work) pair from gather_async whose gathered
            tensor has shape [batch_size1, len_q, len_kv, n_head]
        """
        qkv = self.to_qkv(in_data).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)

        # q = self.to_q(in_data)
        # k = self.to_k(in_data)
        # v = self.to_v(in_data)
        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), [q, k, v])

        # q = q * self.scaling

        logits = torch.matmul(q, k.transpose(-1, -2))

        # logits += mask
        if nonbatched_bias is not None:
            # logits += nonbatched_bias.unsqueeze(1)
            bias = gather_async_opp(*nonbatched_bias, dim=1)
            bias = rearrange(bias, 'b q k h -> b h q k')
            weights = scale_mask_bias_softmax(logits, mask, bias.unsqueeze(1), self.scaling)
        else:
            weights = scale_mask_softmax(logits, mask, self.scaling)
        # weights = torch.softmax(logits, dim=-1)
        # weights = softmax(logits)

        weighted_avg = torch.matmul(weights, v)
        weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')

        if self.gating:
            gate_values = self.gating_linear(in_data)
            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)

        output = self.o_linear(weighted_avg)

        return output
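

# --- Usage sketch (editor's addition, not part of the original commit) ---
# SelfAttention is shared by the MSA and pair modules; qkv_dim is the input channel count and
# c the per-head channel count. With gating=True the output is gated by a sigmoid, as in
# AlphaFold's attention blocks. The forward pass relies on the fused scale_mask_softmax CUDA
# kernel, so this sketch only constructs the module and inspects its projection shapes.
if __name__ == "__main__":
    attn = SelfAttention(qkv_dim=256, c=32, n_head=8, out_dim=256, gating=True)
    print(attn.to_qkv.weight.shape)    # (3 * n_head * c, qkv_dim) = (768, 256)
    print(attn.o_linear.weight.shape)  # (out_dim, n_head * c) = (256, 256)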
import torch
import torch.nn as nn

from fastfold.model.kernel import LayerNorm
from fastfold.distributed.comm import col_to_row, row_to_col, scatter
from fastfold.model.kernel import bias_dropout_add, bias_ele_dropout_residual
from fastfold.model.ops import Linear, SelfAttention, Transition
from fastfold.distributed.comm_async import gather_async_opp, gather_async


def permute_final_dims(tensor, inds):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


class TriangleMultiplicationOutgoing(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationOutgoing, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_right_projection = Linear(d_pair, 2 * c)
        self.left_right_gate = Linear(d_pair, 2 * c, initializer='zeros', bias_init=1.)
        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw, Z_mask_row):
        Z = self.layernorm1(Z_raw)
        left_right_proj_act = self.left_right_projection(Z)

        left_right_proj_act = Z_mask_row.unsqueeze(-1) * left_right_proj_act
        left_right_proj_act *= torch.sigmoid(self.left_right_gate(Z))

        left_proj_act, right_proj_act = left_right_proj_act.chunk(2, dim=-1)
        # right_proj_act = gather(right_proj_act.contiguous(), dim=1)
        right_proj_act, work = gather_async(right_proj_act.contiguous(), dim=1)

        g = torch.sigmoid(self.output_gate(Z))

        left_proj_act = permute_final_dims(left_proj_act, (2, 0, 1))
        right_proj_act = gather_async_opp(right_proj_act, work, dim=1)
        p = torch.matmul(
            left_proj_act,
            permute_final_dims(right_proj_act, (2, 1, 0)),
        )
        ab = permute_final_dims(p, (1, 2, 0))
        # ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)

        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
        return bias_ele_dropout_residual(ab,
                                         self.output_bias,
                                         g,
                                         dropout_mask,
                                         Z_raw,
                                         prob=self.p_drop)


class TriangleMultiplicationIncoming(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationIncoming, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_right_projection = Linear(d_pair, 2 * c)
        self.left_right_gate = Linear(d_pair, 2 * c, initializer='zeros', bias_init=1.)
        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw, Z_mask_col):
        Z = self.layernorm1(Z_raw)
        left_right_proj_act = self.left_right_projection(Z)

        left_right_proj_act = Z_mask_col.unsqueeze(-1) * left_right_proj_act
        left_right_proj_act *= torch.sigmoid(self.left_right_gate(Z))

        left_proj_act, right_proj_act = left_right_proj_act.chunk(2, dim=-1)
        left_proj_act, work = gather_async(left_proj_act.contiguous(), dim=2)

        g = torch.sigmoid(self.output_gate(Z))

        right_proj_act = permute_final_dims(right_proj_act, (2, 0, 1))
        left_proj_act = gather_async_opp(left_proj_act, work, dim=2)
        p = torch.matmul(
            permute_final_dims(left_proj_act, (2, 1, 0)),
            right_proj_act
        )
        ab = permute_final_dims(p, (1, 2, 0))
        # ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)

        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
        return bias_ele_dropout_residual(ab,
                                         self.output_bias,
                                         g,
                                         dropout_mask,
                                         Z_raw,
                                         prob=self.p_drop)


class TriangleAttentionStartingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionStartingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        # _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
        #                                       std=1.0 / math.sqrt(d_pair))
        # self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
        self.linear_b = Linear(d_pair, n_head, initializer='linear', use_bias=False)
        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw, Z_mask):
        Z = self.layernorm1(Z_raw)
        b = self.linear_b(Z)
        # b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
        b, work = gather_async(b, dim=1)
        # b = rearrange(b, 'b q k h -> b h q k')

        # padding_bias = (1e9 * (Z_mask - 1.))[:, :, None, None, :]
        Z = self.attention(Z, Z_mask, (b, work))

        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleAttentionEndingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionEndingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        # _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
        #                                       std=1.0 / math.sqrt(d_pair))
        # self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
        self.linear_b = Linear(d_pair, n_head, initializer='linear', use_bias=False)
        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw, Z_mask):
        Z = Z_raw.transpose(-2, -3)
        Z_mask = Z_mask.transpose(-1, -2)

        Z = self.layernorm1(Z)
        b = self.linear_b(Z)
        # b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
        b, work = gather_async(b, dim=1)
        # b = rearrange(b, 'b q k h -> b h q k')

        # padding_bias = (1e9 * (Z_mask - 1.))[:, :, None, None, :]
        Z = self.attention(Z, Z_mask, (b, work))

        Z = Z.transpose(-2, -3)
        dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class PairStack(nn.Module):

    def __init__(self, d_pair, p_drop=0.25):
        super(PairStack, self).__init__()

        self.d_pair = d_pair
        self.n_head = 4
        self.hidden_c = int(d_pair / self.n_head)

        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop, c=d_pair)
        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop, c=d_pair)
        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop, c=self.hidden_c, n_head=self.n_head)
        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop, c=self.hidden_c, n_head=self.n_head)
        self.PairTransition = Transition(d=d_pair)

    def forward(self, pair, pair_mask):
        pair_mask_row = scatter(pair_mask, dim=1)
        pair_mask_col = scatter(pair_mask, dim=2)
        pair = self.TriangleMultiplicationOutgoing(pair, pair_mask_row)
        pair = row_to_col(pair)
        pair = self.TriangleMultiplicationIncoming(pair, pair_mask_col)
        pair = col_to_row(pair)
        pair = self.TriangleAttentionStartingNode(pair, pair_mask_row)
        pair = row_to_col(pair)
        pair = self.TriangleAttentionEndingNode(pair, pair_mask_col)
        pair = self.PairTransition(pair)
        pair = col_to_row(pair)
        return pair
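

# --- Usage sketch (editor's addition, not part of the original commit) ---
# PairStack applies the pair-track updates in the usual Evoformer order: triangle
# multiplication (outgoing, then incoming), triangle attention (starting, then ending node),
# and a transition, with row/column redistribution between the steps. Running forward needs
# an initialized model-parallel group (scatter/row_to_col/col_to_row) and the fused CUDA
# kernels, so this sketch only builds the block.
if __name__ == "__main__":
    pair_stack = PairStack(d_pair=128, p_drop=0.25)
    n_params = sum(p.numel() for p in pair_stack.parameters())
    print(f"PairStack parameters: {n_params}")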
einops
colossalai
import os
import subprocess
import sys

import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME

# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]
    return raw_output, bare_metal_major, bare_metal_minor


def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
    torch_binary_major = torch.version.cuda.split(".")[0]
    torch_binary_minor = torch.version.cuda.split(".")[1]

    print("\nCompiling cuda extensions with")
    print(raw_output + "from " + cuda_dir + "/bin\n")

    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
        raise RuntimeError(
            "Cuda extensions are being compiled with a version of Cuda that does " +
            "not match the version used to compile Pytorch binaries. " +
            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) +
            "In some cases, a minor-version mismatch will not cause later errors: " +
            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
            "You can try commenting out this check (at your own risk).")


def append_nvcc_threads(nvcc_extra_args):
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    # Compare (major, minor) as a tuple so that, e.g., CUDA 12.0 also enables parallel nvcc.
    if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args


def fetch_requirements(path):
    with open(path, 'r') as fd:
        return [r.strip() for r in fd.readlines()]


if not torch.cuda.is_available():
    # https://github.com/NVIDIA/apex/issues/486
    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
    # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
    print(
        '\nWarning: Torch did not find available GPUs on this system.\n',
        'If your intention is to cross-compile, this is not an error.\n'
        'By default, FastFold will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
        'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
        'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
        'If you wish to cross-compile for a single specific architecture,\n'
        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
        _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
        if int(bare_metal_major) == 11:
            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
        else:
            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"

print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 10):
    raise RuntimeError("FastFold requires Pytorch 1.10 or newer.\n" +
                       "The latest stable release can be obtained from https://pytorch.org/")

cmdclass = {}
ext_modules = []

# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
# and
# https://github.com/NVIDIA/apex/issues/456
# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']

if "--cuda_ext" in sys.argv:
    sys.argv.remove("--cuda_ext")

    if CUDA_HOME is None:
        raise RuntimeError(
            "--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc."
        )
    else:
        check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)

    def cuda_ext_helper(name, sources, extra_cuda_flags):
        return CUDAExtension(
            name=name,
            sources=[
                os.path.join('fastfold/model/kernel/cuda_native/csrc', path) for path in sources
            ],
            include_dirs=[os.path.join(this_dir, 'fastfold/model/kernel/cuda_native/csrc')],
            extra_compile_args={
                'cxx': ['-O3'] + version_dependent_macros,
                'nvcc':
                    append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros +
                                        extra_cuda_flags)
            })

    cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']

    _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')

    extra_cuda_flags = [
        '-std=c++14', '-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__',
        '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', '--expt-extended-lambda'
    ]

    ext_modules.append(
        cuda_ext_helper('fastfold_layer_norm_cuda',
                        ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'],
                        extra_cuda_flags + cc_flag))

    ext_modules.append(
        cuda_ext_helper('fastfold_softmax_cuda', ['softmax_cuda.cpp', 'softmax_cuda_kernel.cu'],
                        extra_cuda_flags + cc_flag))

install_requires = fetch_requirements('./requirements.txt')

setup(
    name='fastfold',
    version='0.0.1-beta',
    packages=find_packages(exclude=(
        'assets',
        'benchmark',
        'notebooks',
        'scripts',
        '*.egg-info',
    )),
    description=
    'Optimizing Protein Structure Prediction Model Training and Inference on GPU Clusters',
    ext_modules=ext_modules,
    cmdclass={'build_ext': BuildExtension} if ext_modules else {},
    install_requires=install_requires,
)
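
# --- Build note (editor's addition, not part of the original commit) ---
# With the layout above, the CUDA extensions (fastfold_layer_norm_cuda, fastfold_softmax_cuda)
# are only built when the --cuda_ext flag is passed, e.g.:
#
#   python setup.py install --cuda_ext
#
# nvcc must be on PATH and its CUDA version should match the one PyTorch was built with;
# TORCH_CUDA_ARCH_LIST can be exported beforehand to target a specific GPU architecture.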