"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "e466496785ef8990d996ebd9d321505fa42c0660"
Commit edb7d341 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add memory-efficient attention kernels

parent 816c1843
......@@ -26,6 +26,10 @@ OpenFold is equipped with an implementation of low-memory attention
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)), which
enables inference on extremely long chains.
We've modified FastFold's custom CUDA kernels to support in-place attention
during inference and training. These kernels use 4x and 5x less GPU memory
than the equivalent FastFold and stock PyTorch implementations, respectively.
We also make available efficient scripts for generating alignments. We've
used them to generate millions of alignments that will be released alongside
original OpenFold weights, trained from scratch using our code (more on that soon).
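As a rough illustration, the new kernel is toggled per call. Below is a minimal inference-time sketch, assuming a CUDA device and the `Attention` module constructed the same way the diff below constructs it in `MSAAttention`; the shapes and sizes are hypothetical:
```python
# Minimal sketch (assumed shapes/values): route attention through the
# new in-place CUDA kernel instead of the stock PyTorch path.
import torch
from openfold.model.primitives import Attention

c_in, c_hidden, no_heads = 64, 32, 8
mha = Attention(c_in, c_in, c_in, c_hidden, no_heads).cuda()

m = torch.rand(128, 64, c_in, device="cuda")  # e.g. [N_seq, N_res, C_in]
with torch.no_grad():
    out = mha(q_x=m, kv_x=m, biases=[], use_memory_efficient_kernel=True)
```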
......@@ -57,6 +61,12 @@ To deactivate it, run:
source scripts/deactivate_conda_env.sh
```
With the environment active, compile OpenFold's CUDA kernels with
```bash
python3 setup.py install
```
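If the build succeeds, the compiled extension should be importable from Python. A quick sanity check (the module name `attn_core_inplace_cuda` is the one registered in this commit's `setup.py`):
```python
# Verify the compiled kernel extension is importable and exposes the
# expected entry points.
import importlib

attn = importlib.import_module("attn_core_inplace_cuda")
assert hasattr(attn, "forward_") and hasattr(attn, "backward_")
```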
To install the HH-suite to `/usr/bin`, run
```bash
......@@ -138,13 +148,6 @@ to `None` in the config.
### Training
After activating the OpenFold environment with
`source scripts/activate_conda_env.sh`, install OpenFold by running
```bash
python setup.py install
```
To train the model, you will first need to precompute protein alignments.
You have two options. You can use the same procedure DeepMind used by running
......
......@@ -368,6 +368,7 @@ class ExtraMSABlock(nn.Module):
z=z.clone() if torch.is_grad_enabled() else z,
mask=msa_mask,
chunk_size=chunk_size,
use_memory_efficient_kernel=not _chunk_logits,
_chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
......@@ -558,11 +559,14 @@ class ExtraMSAStack(nn.Module):
eps: float,
ckpt: bool,
clear_cache_between_blocks: bool = False,
chunk_msa_attn: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
self.ckpt = ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self.chunk_msa_attn = chunk_msa_attn
self.blocks = nn.ModuleList()
for _ in range(no_blocks):
block = ExtraMSABlock(
......@@ -579,7 +583,7 @@ class ExtraMSAStack(nn.Module):
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
ckpt=ckpt,
ckpt=ckpt if chunk_msa_attn else False,
)
self.blocks.append(block)
......@@ -603,28 +607,36 @@ class ExtraMSAStack(nn.Module):
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
#checkpoint_fn = get_checkpoint_fn()
#blocks = [
# partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
#]
#def dodo(b, *args):
# torch.cuda.empty_cache()
# return b(*args)
#blocks = [partial(dodo, b) for b in blocks]
"""
if(not self.chunk_msa_attn):
checkpoint_fn = get_checkpoint_fn()
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_chunk_logits=None
) for b in self.blocks
]
def clear_cache(b, *args):
torch.cuda.empty_cache()
return b(*args)
#for b in blocks:
# if(torch.is_grad_enabled()):
# m, z = checkpoint_fn(b, *(m, z))
# else:
# m, z = b(m, z)
if(self.clear_cache_between_blocks):
blocks = [partial(clear_cache, b) for b in blocks]
for b in self.blocks:
m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
for b in blocks:
if(self.ckpt and torch.is_grad_enabled()):
m, z = checkpoint_fn(b, *(m, z))
else:
m, z = b(m, z)
else:
for b in self.blocks:
m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
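For reference, the pattern above (bind the static arguments with `partial`, then checkpoint each block only when gradients are enabled) distills into a standalone sketch. The toy `block` below is a hypothetical stand-in for `ExtraMSABlock`, and plain `torch.utils.checkpoint` stands in for OpenFold's `get_checkpoint_fn`:
```python
# Standalone sketch of the block loop above: static kwargs are bound
# with partial, and each block is gradient-checkpointed only when
# gradients are enabled. `block` is a toy stand-in for ExtraMSABlock.
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint

def block(m, z, scale=1.0):
    return m * scale, z + m.mean()

blocks = [partial(block, scale=s) for s in (0.5, 2.0)]

m = torch.rand(4, 8, requires_grad=True)
z = torch.rand(4, 8)
for b in blocks:
    if torch.is_grad_enabled():
        # Activations are recomputed during backward to save memory
        m, z = checkpoint(b, m, z)
    else:
        m, z = b(m, z)
```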
......@@ -79,20 +79,30 @@ class MSAAttention(nn.Module):
)
self.mha = Attention(
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
self.c_in,
self.c_in,
self.c_in,
self.c_hidden,
self.no_heads,
)
@torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
biases: List[torch.Tensor],
use_memory_efficient_kernel: bool,
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self.mha,
{"q_x": m, "kv_x": m, "biases": biases},
{
"q_x": m,
"kv_x": m,
"biases": biases,
"use_memory_efficient_kernel": use_memory_efficient_kernel,
},
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
no_batch_dims=len(m.shape[:-2])
)
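`chunk_layer` (from `openfold.utils.tensor_utils`) applies the wrapped module to slices of the batch dimensions and concatenates the results. A simplified single-batch-dim sketch of the idea; the real helper generalizes to arbitrary `no_batch_dims` and handles broadcasting:
```python
# Simplified illustration of chunk_layer for one batch dimension:
# run the layer on chunk_size-sized slices and stitch the outputs.
import torch

def chunk_layer_1d(layer, inputs, chunk_size):
    n = next(iter(inputs.values())).shape[0]
    out = []
    for i in range(0, n, chunk_size):
        sliced = {k: v[i:i + chunk_size] for k, v in inputs.items()}
        out.append(layer(**sliced))
    return torch.cat(out, dim=0)

# e.g.: process 4 rows of an MSA at a time
# m = chunk_layer_1d(self.mha, {"q_x": m, "kv_x": m}, chunk_size=4)
```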
def _prep_inputs(self,
......@@ -113,13 +123,6 @@ class MSAAttention(nn.Module):
# [*, N_seq, 1, 1, N_res]
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# This step simply returns a larger view of the bias, and does not
# consume additional memory.
# [*, N_seq, no_heads, N_res, N_res]
#bias = bias.expand(
# ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
#)
if (self.pair_bias and
z is not None and # For the
self.layer_norm_z is not None and # benefit of
......@@ -144,6 +147,11 @@ class MSAAttention(nn.Module):
chunk_logits: int,
checkpoint: bool,
) -> torch.Tensor:
"""
MSA attention with training-time chunking of the softmax computation.
Saves memory in the extra MSA stack. Probably obviated by our fused
attention kernel, which is now used by default.
"""
MSA_DIM = -4
def _get_qkv(m, z):
......@@ -181,6 +189,7 @@ class MSAAttention(nn.Module):
z: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
_chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor:
......@@ -212,12 +221,13 @@ class MSAAttention(nn.Module):
biases.append(z)
if chunk_size is not None:
m = self._chunk(m, biases, chunk_size)
m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size)
else:
m = self.mha(
q_x=m,
kv_x=m,
biases=biases
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
)
return m
......@@ -291,7 +301,8 @@ class MSAColumnAttention(nn.Module):
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
) -> torch.Tensor:
"""
Args:
......
......@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
from typing import Optional, Callable, List, Tuple, Sequence
......@@ -24,6 +23,7 @@ import torch.nn as nn
from scipy.stats import truncnorm
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.kernel.attention_core import attention_core
from openfold.utils.tensor_utils import (
permute_final_dims,
flatten_final_dims,
......@@ -199,8 +199,9 @@ class LayerNorm(nn.Module):
return out
@torch.jit.ignore
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""
Softmax, but without automatic casting to fp32 when the input is of
type bfloat16
......@@ -217,14 +218,8 @@ def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
#@torch.jit.script
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
# [*, H, Q, C_hidden]
query = permute_final_dims(query, (1, 0, 2))
# [*, H, C_hidden, K]
key = permute_final_dims(key, (1, 2, 0))
# [*, H, V, C_hidden]
value = permute_final_dims(value, (1, 0, 2))
key = permute_final_dims(key, (1, 0))
# [*, H, Q, K]
a = torch.matmul(query, key)
......@@ -232,14 +227,11 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias
for b in biases:
a += b
a = softmax(a, -1)
a = softmax_no_cast(a, -1)
# [*, H, Q, C_hidden]
a = torch.matmul(a, value)
# [*, Q, H, C_hidden]
a = a.transpose(-2, -3)
return a
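The refactor removes the per-call permutes: `_prep_qkv` now hands `_attention` tensors already in [*, H, Q/K, C_hidden] layout, so only the key needs its last two dims swapped (`permute_final_dims(key, (1, 0))` is equivalent to `transpose(-1, -2)`). A shape walkthrough under assumed toy dimensions:
```python
# Shape walkthrough of the refactored _attention (toy dimensions):
# q/k/v arrive pre-transposed, so only the key is flipped for matmul.
import torch

B, H, Q, K, C = 2, 8, 16, 16, 32
q = torch.rand(B, H, Q, C)   # [*, H, Q, C_hidden]
k = torch.rand(B, H, K, C)   # [*, H, K, C_hidden]
v = torch.rand(B, H, K, C)   # [*, H, K, C_hidden]

a = q @ k.transpose(-1, -2)  # [*, H, Q, K] attention logits
a = a.softmax(dim=-1)
o = a @ v                    # [*, H, Q, C_hidden]
assert o.shape == (B, H, Q, C)
```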
......@@ -254,7 +246,8 @@ def _attention_chunked_trainable(
def _checkpointable_attention(q, k, v, b1, b2):
bs = [b for b in [b1, b2] if b is not None]
return _attention(q, k, v, bs)
a = _attention(q, k, v, bs)
return a
o_chunks = []
checkpoint_fn = get_checkpoint_fn()
......@@ -289,7 +282,8 @@ def _attention_chunked_trainable(
]
o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
o_chunk = o_chunk.transpose(-2, -3)
o_chunks.append(o_chunk)
o = torch.cat(o_chunks, dim=chunk_dim)
......@@ -374,6 +368,11 @@ class Attention(nn.Module):
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
# [*, H, Q/K, C_hidden]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
q /= math.sqrt(self.c_hidden)
return q, k, v
......@@ -402,6 +401,7 @@ class Attention(nn.Module):
q_x: torch.Tensor,
kv_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
q_chunk_size: Optional[int] = None,
kv_chunk_size: Optional[int] = None,
......@@ -414,8 +414,15 @@ class Attention(nn.Module):
[*, K, C_k] key data
biases:
List of biases that broadcast to [*, H, Q, K]
use_memory_efficient_kernel:
Whether to use a custom memory-efficient attention kernel.
This should be the default choice for most users. If none of
the "use_<...>" flags are True, a stock PyTorch implementation
is used instead
use_lma:
Whether to use low-memory attention
Whether to use low-memory attention (Rabe & Staats 2021). If
none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead
q_chunk_size:
Query chunk size (for LMA)
kv_chunk_size:
......@@ -430,18 +437,32 @@ class Attention(nn.Module):
"If use_lma is specified, q_chunk_size and kv_chunk_size must "
"be provided"
)
if(use_memory_efficient_kernel and use_lma):
raise ValueError(
"Choose one of use_memory_efficient_kernel and use_lma"
)
# [*, H, Q/K, C_hidden]
q, k, v = self._prep_qkv(q_x, kv_x)
if(use_lma):
# [*, Q, H, C_hidden]
if(use_memory_efficient_kernel):
if(len(biases) > 2):
raise ValueError(
"If use_memory_efficient_kernel is True, you may only "
"provide up to two bias terms"
)
o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
o = o.transpose(-2, -3)
elif(use_lma):
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases
]
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
else:
o = _attention(q, k, v, biases)
o = o.transpose(-2, -3)
o = self._wrap_up(o, q_x)
......@@ -497,7 +518,7 @@ class GlobalAttention(nn.Module):
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a = softmax(a)
a = softmax_no_cast(a)
# [*, N_res, H, C_hidden]
o = torch.matmul(
......
......@@ -2,12 +2,14 @@ import os
import glob
import importlib as importlib
from . import kernel
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
] + ["kernel"]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
......
import importlib
from functools import reduce
from operator import mul
import torch
attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
class AttentionCoreFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, bias_1=None, bias_2=None):
if(bias_1 is None and bias_2 is not None):
raise ValueError("bias_1 must be specified before bias_2")
q = q.contiguous()
k = k.contiguous()
# [*, H, Q, K]
attention_logits = torch.matmul(
q, k.transpose(-1, -2),
)
if(bias_1 is not None):
attention_logits += bias_1
if(bias_2 is not None):
attention_logits += bias_2
attn_core_inplace_cuda.forward_(
attention_logits,
reduce(mul, attention_logits.shape[:-1]),
attention_logits.shape[-1],
)
o = torch.matmul(attention_logits, v)
ctx.bias_1_shape = bias_1.shape if bias_1 is not None else None
ctx.bias_2_shape = bias_2.shape if bias_2 is not None else None
ctx.save_for_backward(q, k, v, attention_logits)
return o
@staticmethod
def backward(ctx, grad_output):
q, k, v, attention_logits = ctx.saved_tensors
grad_q = grad_k = grad_v = grad_bias_1 = grad_bias_2 = None
grad_v = torch.matmul(
attention_logits.transpose(-1, -2),
grad_output
)
attn_core_inplace_cuda.backward_(
attention_logits,
grad_output.contiguous(),
v.contiguous(), # v is implicitly transposed in the kernel
reduce(mul, attention_logits.shape[:-1]),
attention_logits.shape[-1],
grad_output.shape[-1],
)
if(ctx.bias_1_shape is not None):
grad_bias_1 = torch.sum(
attention_logits,
dim=tuple(i for i,d in enumerate(ctx.bias_1_shape) if d == 1),
keepdim=True,
)
if(ctx.bias_2_shape is not None):
grad_bias_2 = torch.sum(
attention_logits,
dim=tuple(i for i,d in enumerate(ctx.bias_2_shape) if d == 1),
keepdim=True,
)
grad_q = torch.matmul(
attention_logits, k
)
grad_k = torch.matmul(
q.transpose(-1, -2), attention_logits,
).transpose(-1, -2)
return grad_q, grad_k, grad_v, grad_bias_1, grad_bias_2
attention_core = AttentionCoreFunction.apply
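A hedged usage sketch for the wrapper above (shapes mirror the unit test at the bottom of this commit; tensors must live on a CUDA device, in fp32 or bf16, with at most two optional bias terms):
```python
# Usage sketch for the autograd wrapper above. The softmax kernel
# operates in place on the attention logits, which are saved for the
# backward pass instead of a separate probability tensor.
import torch

n_seq, h, n_res, c = 4, 8, 32, 16
q = torch.rand(n_seq, h, n_res, c, device="cuda", requires_grad=True)
k = torch.rand(n_seq, h, n_res, c, device="cuda", requires_grad=True)
v = torch.rand(n_seq, h, n_res, c, device="cuda", requires_grad=True)
mask_bias = torch.zeros(n_seq, 1, 1, n_res, device="cuda")

o = attention_core(q, k, v, mask_bias, None)  # [n_seq, h, n_res, c]
o.mean().backward()
```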
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/extension.h>
// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
void attn_softmax_inplace_forward_(
at::Tensor input,
long long rows, int cols
);
void attn_softmax_inplace_backward_(
at::Tensor output,
at::Tensor d_ov,
at::Tensor values,
long long rows,
int cols_output,
int cols_values
);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"forward_",
&attn_softmax_inplace_forward_,
"Softmax forward (CUDA)"
);
m.def(
"backward_",
&attn_softmax_inplace_backward_,
"Softmax backward (CUDA)"
);
}
#include <math_constants.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <iostream>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
__inline__ __device__ float WarpAllReduceMax(float val) {
for (int mask = 1; mask < 32; mask *= 2) {
val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
__inline__ __device__ float WarpAllReduceSum(float val) {
for (int mask = 1; mask < 32; mask *= 2) {
val += __shfl_xor_sync(0xffffffff, val, mask);
}
return val;
}
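`WarpAllReduceMax` and `WarpAllReduceSum` are XOR-butterfly reductions: after log2(32) = 5 exchange steps with masks 1, 2, 4, 8, 16, every lane of the warp holds the full reduction. A purely illustrative Python emulation of the max variant:
```python
# Emulates the XOR-butterfly warp reduction: at each step, lane i
# combines its value with that of lane i ^ mask. After 5 steps, all
# 32 "lanes" hold the warp-wide maximum.
import torch

lanes = torch.rand(32)
val = lanes.clone()
mask = 1
while mask < 32:
    partner = torch.arange(32) ^ mask  # lane paired with by shfl_xor
    val = torch.maximum(val, val[partner])
    mask *= 2
assert bool((val == lanes.max()).all())
```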
template<typename T>
__global__ void attn_softmax_inplace_(
T *input,
long long rows, int cols
) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x);
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
T *row_input = input + row_offset * cols;
T *row_output = row_input;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
int idx = lane_id * cols_per_thread + i;
buf[i] = static_cast<float>(row_input[idx]);
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
row_output[lane_id * cols_per_thread + i] =
static_cast<T>(__fdividef(buf[i], warp_sum));
}
}
}
void attn_softmax_inplace_forward_(
at::Tensor input,
long long rows, int cols
) {
CHECK_INPUT(input);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
int grid = (rows + 3) / 4;
dim3 block(128);
if (input.dtype() == torch::kFloat32) {
attn_softmax_inplace_<float><<<grid, block>>>(
(float *)input.data_ptr(),
rows, cols
);
}
else {
attn_softmax_inplace_<at::BFloat16><<<grid, block>>>(
(at::BFloat16 *)input.data_ptr(),
rows, cols
);
}
}
template<typename T>
__global__ void attn_softmax_inplace_grad_(
//__global__ void attn_softmax_inplace_grad_bf16_(
T *output,
T *d_ov,
T *values,
long long rows,
int cols_output,
int cols_values
) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x);
int cols_per_thread = (cols_output + 31) / 32;
int cols_this_thread = cols_per_thread;
int rows_values = cols_output;
// values are set to the beginning of the current
// rows_values x cols_values leaf matrix
long long value_row_offset = row_offset - row_offset % rows_values;
int last_y = (cols_output / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols_output - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
T *row_output = output + row_offset * cols_output;
T *row_d_ov = d_ov + row_offset * cols_values;
T *row_values = values + value_row_offset * cols_values;
float thread_max = -1 * CUDART_INF_F;
// Compute a chunk of the output gradient on the fly
int value_row_idx = 0;
int value_idx = 0;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
T sum = 0.;
#pragma unroll
for (int j = 0; j < cols_values; j++) {
value_row_idx = ((lane_id * cols_per_thread) + i);
value_idx = value_row_idx * cols_values + j;
sum += row_d_ov[j] * row_values[value_idx];
}
dy_buf[i] = static_cast<float>(sum);
}
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
}
float thread_sum = 0.;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
row_output[lane_id * cols_per_thread + i] = static_cast<T>(
(dy_buf[i] - warp_sum) * y_buf[i]
);
}
}
}
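The backward kernel first reconstructs dY = dO · Vᵀ on the fly (the inner `sum` loop), then applies the softmax Jacobian identity dX_i = Y_i (dY_i − Σ_j Y_j dY_j) in place over the saved logits. A quick autograd check of that identity:
```python
# Numerical check of the softmax-backward identity implemented above:
#   dL/dx_i = y_i * (dy_i - sum_j y_j * dy_j), where y = softmax(x).
import torch

x = torch.rand(8, requires_grad=True)
y = torch.softmax(x, dim=-1)
dy = torch.rand(8)
y.backward(dy)

manual = y.detach() * (dy - (y.detach() * dy).sum())
assert torch.allclose(x.grad, manual, atol=1e-6)
```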
void attn_softmax_inplace_backward_(
at::Tensor output,
at::Tensor d_ov,
at::Tensor values,
long long rows,
int cols_output,
int cols_values
) {
CHECK_INPUT(output);
CHECK_INPUT(d_ov);
CHECK_INPUT(values);
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
int grid = (rows + 3) / 4;
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
attn_softmax_inplace_grad_<float><<<grid, block>>>(
(float *)output.data_ptr(),
(float *)d_ov.data_ptr(),
(float *)values.data_ptr(),
rows, cols_output, cols_values
);
} else {
attn_softmax_inplace_grad_<at::BFloat16><<<grid, block>>>(
(at::BFloat16 *)output.data_ptr(),
(at::BFloat16 *)d_ov.data_ptr(),
(at::BFloat16 *)values.data_ptr(),
rows, cols_output, cols_values
);
}
}
......@@ -12,8 +12,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import find_packages
from setuptools import setup
import os
from setuptools import setup, Extension, find_packages
import subprocess
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
version_dependent_macros = [
'-DVERSION_GE_1_1',
'-DVERSION_GE_1_3',
'-DVERSION_GE_1_5',
]
extra_cuda_flags = [
'-std=c++14',
'-maxrregcount=50',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda'
]
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
extra_cuda_flags += cc_flag
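The net effect is that a CUDA 11+ toolchain compiles for both Volta (sm_70) and Ampere (sm_80). A toy replay of the flag selection under an assumed major version:
```python
# Replays the flag selection above for an assumed CUDA 11 toolkit
# (the real major version comes from parsing `nvcc -V`).
bare_metal_major = "11"  # hypothetical value

cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
if int(bare_metal_major) >= 11:
    cc_flag += ['-gencode', 'arch=compute_80,code=sm_80']

# -> targets both Volta (sm_70) and Ampere (sm_80)
```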
setup(
name='openfold',
......@@ -25,7 +63,32 @@ setup(
url='https://github.com/aqlaboratory/openfold',
packages=find_packages(exclude=["tests", "scripts"]),
include_package_data=True,
package_data={"": ["resources/stereo_chemical_props.txt"]},
package_data={
"openfold": ['utils/kernel/csrc/*'],
"": ["resources/stereo_chemical_props.txt"]
},
ext_modules=[CUDAExtension(
name="attn_core_inplace_cuda",
sources=[
"openfold/utils/kernel/csrc/softmax_cuda.cpp",
"openfold/utils/kernel/csrc/softmax_cuda_kernel.cu",
],
include_dirs=[
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'openfold/utils/kernel/csrc/'
)
],
extra_compile_args={
'cxx': ['-O3'] + version_dependent_macros,
'nvcc': (
['-O3', '--use_fast_math'] +
version_dependent_macros +
extra_cuda_flags
),
}
)],
cmdclass={'build_ext': BuildExtension},
install_requires=[
'torch',
'deepspeed',
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import unittest
from openfold.model.primitives import _attention
from openfold.utils.kernel.attention_core import attention_core
from tests.config import consts
class TestAttentionCore(unittest.TestCase):
def test_attention_core_forward(self):
n_res = consts.n_res
h = consts.n_heads_extra_msa
n_seq = consts.n_extra
c = consts.c_e
dtype = torch.float32
q = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
k = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
v = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
mask_bias = (1e9 * (mask - 1))[..., None, None, :].to(dtype)
out_repro = attention_core(q, k, v, mask_bias, None)
out_gt = _attention(q, k, v, [mask_bias])
self.assertTrue(torch.max(torch.abs(out_repro - out_gt)) < consts.eps)
def test_attention_core_backward(self):
n_res = consts.n_res
h = consts.n_heads_extra_msa
n_seq = consts.n_extra
c = consts.c_e
dtype = torch.float32
q = torch.rand(
[n_seq, h, n_res, c], dtype=dtype, requires_grad=True
).cuda()
k = torch.rand(
[n_seq, h, n_res, c], dtype=dtype, requires_grad=True
).cuda()
v = torch.rand(
[n_seq, h, n_res, c], dtype=dtype, requires_grad=True
).cuda()
mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
mask_bias = (1e9 * (mask - 1))[..., None, None, :].to(dtype)
def clone(t):
t = t.clone()
if(t.requires_grad):
t.retain_grad()
return t
q_repro = clone(q)
k_repro = clone(k)
v_repro = clone(v)
out_repro = attention_core(
q_repro, k_repro, v_repro, mask_bias, None
)
loss_repro = torch.mean(out_repro)
loss_repro.backward()
q_gt = clone(q)
k_gt = clone(k)
v_gt = clone(v)
out_gt = _attention(
q_gt, k_gt, v_gt, [mask_bias]
)
loss_gt = torch.mean(out_gt)
loss_gt.backward()
pairs = zip([q_repro, k_repro, v_repro], [q_gt, k_gt, v_gt])
for t_repro, t_gt in pairs:
self.assertTrue(
torch.max(torch.abs(t_repro.grad - t_gt.grad)) < consts.eps
)
if __name__ == '__main__':
unittest.main()