"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "e685140975ca203242e826e13b088654509d6620"
Commit f0a320e0 authored by Christina Floristean

Integrated the DeepSpeed attention kernel and added initial tests.

parent 2134cc09
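Usage note (not part of the diff): a minimal sketch of how the new kernel can be switched on through the model's globals config, based on the flag added below. The model_config helper and the preset name are illustrative assumptions, not shown in this commit.

# Illustrative sketch: enable the DeepSpeed Evoformer attention kernel
# introduced by this commit. "model_1" is an example preset name.
from openfold.config import model_config

config = model_config("model_1")
config.globals.use_deepspeed_evo_attention = True   # new flag from this commit
config.globals.use_lma = False                       # the alternative kernels are
config.globals.use_flash = False                     # mutually exclusive with it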
......@@ -10,9 +10,18 @@
"bfloat16": {
"enabled": true
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-3,
"eps": 1e-5
}
},
"zero_optimization": {
"stage": 2,
"cpu_offload": true,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true
},
"activation_checkpointing": {
......
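The optimizer and ZeRO stage-2 offload blocks added to this JSON are consumed when the DeepSpeed engine is built. A minimal sketch, assuming the JSON above is saved as deepspeed_config.json and noting that older DeepSpeed versions take the file via config_params instead of config; the stand-in module is illustrative.

# Illustrative sketch: hand the updated JSON config to DeepSpeed at init time.
import torch
import deepspeed

model = torch.nn.Linear(8, 8)            # stand-in module; any nn.Module works
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(), # needed since the optimizer is config-defined
    config="deepspeed_config.json",      # the JSON shown above
)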
......@@ -367,6 +367,7 @@ config = mlc.ConfigDict(
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"use_deepspeed_evo_attention": False,
# Use Staats & Rabe's low-memory attention algorithm. Mutually
# exclusive with use_flash.
"use_lma": False,
......
......@@ -181,6 +181,7 @@ class EvoformerBlockCore(nn.Module):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
......@@ -260,6 +261,7 @@ class EvoformerBlockCore(nn.Module):
mask=pair_mask,
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
......@@ -279,6 +281,7 @@ class EvoformerBlockCore(nn.Module):
mask=pair_mask.transpose(-1, -2),
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
......@@ -365,6 +368,7 @@ class EvoformerBlock(nn.Module):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
......@@ -392,6 +396,7 @@ class EvoformerBlock(nn.Module):
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
)
),
......@@ -403,6 +408,7 @@ class EvoformerBlock(nn.Module):
m,
mask=msa_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
),
......@@ -418,7 +424,8 @@ class EvoformerBlock(nn.Module):
input_tensors,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
......@@ -494,6 +501,7 @@ class ExtraMSABlock(nn.Module):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
......@@ -520,7 +528,8 @@ class ExtraMSABlock(nn.Module):
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_lma=use_lma,
use_memory_efficient_kernel=not use_lma,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention),
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
)
......@@ -554,6 +563,7 @@ class ExtraMSABlock(nn.Module):
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
......@@ -674,6 +684,7 @@ class EvoformerStack(nn.Module):
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool,
use_lma: bool,
use_flash: bool,
msa_mask: Optional[torch.Tensor],
......@@ -687,6 +698,7 @@ class EvoformerStack(nn.Module):
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
inplace_safe=inplace_safe,
......@@ -726,6 +738,7 @@ class EvoformerStack(nn.Module):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
_mask_trans: bool = True,
......@@ -737,6 +750,7 @@ class EvoformerStack(nn.Module):
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
......@@ -768,6 +782,7 @@ class EvoformerStack(nn.Module):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
......@@ -802,6 +817,7 @@ class EvoformerStack(nn.Module):
m=m,
z=z,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
......@@ -882,6 +898,7 @@ class ExtraMSAStack(nn.Module):
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool,
use_lma: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
......@@ -893,7 +910,8 @@ class ExtraMSAStack(nn.Module):
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
......@@ -930,6 +948,7 @@ class ExtraMSAStack(nn.Module):
def _forward_offload(self,
input_tensors: Sequence[torch.Tensor],
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
......@@ -942,6 +961,7 @@ class ExtraMSAStack(nn.Module):
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
......@@ -968,6 +988,7 @@ class ExtraMSAStack(nn.Module):
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
......@@ -992,6 +1013,7 @@ class ExtraMSAStack(nn.Module):
m=m,
z=z,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
......
......@@ -355,6 +355,7 @@ class AlphaFold(nn.Module):
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
......@@ -367,6 +368,7 @@ class AlphaFold(nn.Module):
a, z,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
......@@ -385,6 +387,7 @@ class AlphaFold(nn.Module):
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
......@@ -397,6 +400,7 @@ class AlphaFold(nn.Module):
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
inplace_safe=inplace_safe,
......
......@@ -91,7 +91,8 @@ class MSAAttention(nn.Module):
m: torch.Tensor,
biases: Optional[List[torch.Tensor]],
chunk_size: int,
use_memory_efficient_kernel: bool,
use_memory_efficient_kernel: bool,
use_deepspeed_evo_attention: bool,
use_lma: bool,
use_flash: bool,
flash_mask: Optional[torch.Tensor],
......@@ -103,6 +104,7 @@ class MSAAttention(nn.Module):
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=flash_mask,
......@@ -221,6 +223,7 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
......@@ -267,7 +270,8 @@ class MSAAttention(nn.Module):
m,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
......@@ -279,6 +283,7 @@ class MSAAttention(nn.Module):
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
......@@ -356,6 +361,7 @@ class MSAColumnAttention(nn.Module):
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
) -> torch.Tensor:
......@@ -378,7 +384,8 @@ class MSAColumnAttention(nn.Module):
m = self._msa_att(
m,
mask=mask,
chunk_size=chunk_size,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
)
......
......@@ -12,20 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import importlib
import math
from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if(deepspeed_is_installed):
if deepspeed_is_installed:
import deepspeed
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(fa_is_installed):
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attention import FlashAttention
if fa_is_installed:
from flash_attn.bert_padding import unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
import torch
......@@ -33,7 +32,6 @@ import torch.nn as nn
from scipy.stats import truncnorm
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import _chunk_slice
from openfold.utils.kernel.attention_core import attention_core
from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import (
......@@ -42,8 +40,8 @@ from openfold.utils.tensor_utils import (
)
DEFAULT_LMA_Q_CHUNK_SIZE=1024
DEFAULT_LMA_KV_CHUNK_SIZE=4096
DEFAULT_LMA_Q_CHUNK_SIZE = 1024
DEFAULT_LMA_KV_CHUNK_SIZE = 4096
def _prod(nums):
......@@ -196,9 +194,9 @@ class LayerNorm(nn.Module):
d = x.dtype
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
deepspeed.comm.comm.is_initialized()
)
if(d is torch.bfloat16 and not deepspeed_is_initialized):
if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm(
x,
......@@ -228,9 +226,9 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
d = t.dtype
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
deepspeed.comm.comm.is_initialized()
)
if(d is torch.bfloat16 and not deepspeed_is_initialized):
if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
else:
......@@ -262,7 +260,7 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias
def _attention_chunked_trainable(
query, key, value, biases, chunk_size, chunk_dim, checkpoint,
):
if(checkpoint and len(biases) > 2):
if checkpoint and len(biases) > 2:
raise ValueError(
"Checkpointed version permits only permits two bias terms"
)
......@@ -290,7 +288,7 @@ def _attention_chunked_trainable(
)
return b[tuple(idx)]
if(checkpoint):
if checkpoint:
bias_1_chunk, bias_2_chunk = [
_slice_bias(b) if b is not None else None
for b in (biases + [None, None])[:2]
......@@ -404,7 +402,7 @@ class Attention(nn.Module):
o: torch.Tensor,
q_x: torch.Tensor
) -> torch.Tensor:
if(self.linear_g is not None):
if self.linear_g is not None:
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
......@@ -425,11 +423,12 @@ class Attention(nn.Module):
kv_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
use_flash: bool = False,
flash_mask: Optional[torch.Tensor] = None,
flash_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
......@@ -444,6 +443,10 @@ class Attention(nn.Module):
This should be the default choice for most. If none of the
"use_<...>" flags are True, a stock PyTorch implementation
is used instead
use_deepspeed_evo_attention:
Whether to use DeepSpeed memory-efficient attention kernel.
If none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead
use_lma:
Whether to use low-memory attention (Staats & Rabe 2021). If
none of the "use_<...>" flags are True, a stock PyTorch
......@@ -455,25 +458,25 @@ class Attention(nn.Module):
Returns
[*, Q, C_q] attention update
"""
if(use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None)):
if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
raise ValueError(
"If use_lma is specified, lma_q_chunk_size and "
"lma_kv_chunk_size must be provided"
)
if(use_flash and biases is not None):
if use_flash and biases is not None:
raise ValueError(
"use_flash is incompatible with the bias option. For masking, "
"use flash_mask instead"
)
attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
if(sum(attn_options) > 1):
attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, use_flash]
if sum(attn_options) > 1:
raise ValueError(
"Choose at most one alternative attention algorithm"
)
if(biases is None):
if biases is None:
biases = []
# [*, H, Q/K, C_hidden]
......@@ -483,22 +486,47 @@ class Attention(nn.Module):
if is_fp16_enabled():
use_memory_efficient_kernel = False
if(use_memory_efficient_kernel):
if(len(biases) > 2):
if use_memory_efficient_kernel:
if len(biases) > 2:
raise ValueError(
"If use_memory_efficient_kernel is True, you may only "
"provide up to two bias terms"
)
o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
o = o.transpose(-2, -3)
elif(use_lma):
elif use_deepspeed_evo_attention:
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
add_batch_dim = len(q.shape) < 5
if add_batch_dim:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
biases = [b.unsqueeze(0) for b in biases]
orig_dtype = q.dtype
if orig_dtype not in [torch.bfloat16, torch.float16]:
o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
k.to(dtype=torch.bfloat16),
v.to(dtype=torch.bfloat16),
[b.to(dtype=torch.bfloat16) for b in biases])
o = o.to(dtype=orig_dtype)
else:
o = DS4Sci_EvoformerAttention(q, k, v, biases)
if add_batch_dim:
o = o.squeeze(0)
elif use_lma:
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases
]
o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
o = o.transpose(-2, -3)
elif(use_flash):
elif use_flash:
o = _flash_attn(q, k, v, flash_mask)
else:
o = _attention(q, k, v, biases)
......@@ -556,7 +584,7 @@ class GlobalAttention(nn.Module):
v = self.linear_v(m)
bias = (self.inf * (mask - 1))[..., :, None, :]
if(not use_lma):
if not use_lma:
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
......@@ -662,7 +690,7 @@ def _lma(
@torch.jit.ignore
def _flash_attn(q, k, v, kv_mask):
if(not fa_is_installed):
if not fa_is_installed:
raise ValueError(
"_flash_attn requires that FlashAttention be installed"
)
......@@ -714,8 +742,8 @@ def _flash_attn(q, k, v, kv_mask):
kv_cu_seqlens,
q_max_s,
kv_max_s,
dropout_p = 0.,
softmax_scale = 1., # q has been scaled already
dropout_p=0.,
softmax_scale=1., # q has been scaled already
)
# [*, B, N, H, C]
......
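For reference, a standalone sketch of the tensor layout the new code path implies for DS4Sci_EvoformerAttention: 5-D [batch, n_seq, n_res, heads, head_dim] inputs in bf16/fp16, with biases shaped as in the test below. The sizes are illustrative assumptions; a CUDA device and the DeepSpeed deepspeed4science ops are required.

# Illustrative shapes only; mirrors the transpose/unsqueeze logic added above.
import torch
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

B, S, R, H, C = 1, 13, 20, 4, 32          # batch, n_seq, n_res, heads, head_dim
q = torch.rand(B, S, R, H, C, device="cuda", dtype=torch.bfloat16)
k = torch.rand_like(q)
v = torch.rand_like(q)
biases = [
    torch.rand(B, S, 1, 1, R, device="cuda", dtype=torch.bfloat16),  # mask-style bias
    torch.rand(B, 1, H, R, R, device="cuda", dtype=torch.bfloat16),  # pair bias
]
out = DS4Sci_EvoformerAttention(q, k, v, biases)   # [B, S, R, H, C]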
......@@ -63,6 +63,7 @@ class TriangleAttention(nn.Module):
biases: List[torch.Tensor],
chunk_size: int,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
......@@ -77,6 +78,7 @@ class TriangleAttention(nn.Module):
partial(
self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma
),
mha_inputs,
......@@ -90,6 +92,7 @@ class TriangleAttention(nn.Module):
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
......@@ -130,6 +133,7 @@ class TriangleAttention(nn.Module):
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
......@@ -139,6 +143,7 @@ class TriangleAttention(nn.Module):
kv_x=x,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma
)
......
......@@ -3,7 +3,7 @@ import ml_collections as mlc
consts = mlc.ConfigDict(
{
"batch_size": 2,
"n_res": 11,
"n_res": 20,
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
......
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import unittest
import numpy as np
import pickle
from openfold.model.primitives import (
Attention,
)
from tests.config import consts
import tests.compare_utils as compare_utils
from openfold.data import data_transforms
from openfold.utils.tensor_utils import tensor_tree_map
class TestDeepSpeedKernel(unittest.TestCase):
def test_ds_kernel_vs_attention(self):
batch_size = consts.batch_size
c_hidden = 32
n = 2 ** 12
n_seq = 12
no_heads = 4
q = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
kv = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
bias = [torch.rand(batch_size, n_seq, 1, 1, n), torch.rand(batch_size, 1, no_heads, n, n)]
bias = [b.cuda() for b in bias]
a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
with torch.no_grad():
l = a(q, kv, biases=bias, use_deepspeed_evo_attention=True)
real = a(q, kv, biases=bias)
self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
def compare_evoformer(self, dtype):
n_res = consts.n_res
n_seq = consts.n_seq
activations = {
"msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
"pair": torch.rand(n_res, n_res, consts.c_z, device='cuda', dtype=dtype)
}
masks = {
"msa": torch.randint(0, 2, (n_seq, n_res), device='cuda', dtype=dtype),
"pair": torch.randint(0, 2, (n_res, n_res), device='cuda', dtype=dtype),
}
with torch.cuda.amp.autocast(dtype=dtype):
model = compare_utils.get_global_pretrained_openfold()
out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
activations["msa"],
activations["pair"],
masks["msa"],
masks["pair"],
use_deepspeed_evo_attention=False,
chunk_size=4,
_mask_trans=False,
inplace_safe=False,
)
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
out_repro_msa_ds, out_repro_pair_ds = model.evoformer.blocks[0](
activations["msa"],
activations["pair"],
masks["msa"],
masks["pair"],
use_deepspeed_evo_attention=True,
chunk_size=4,
_mask_trans=False,
inplace_safe=False,
)
out_repro_msa_ds = out_repro_msa_ds.cpu()
out_repro_pair_ds = out_repro_pair_ds.cpu()
self.assertTrue(torch.allclose(torch.abs(out_repro_msa), torch.abs(out_repro_msa_ds), atol=consts.eps))
self.assertTrue(torch.allclose(torch.abs(out_repro_pair), torch.abs(out_repro_pair_ds), atol=consts.eps))
@compare_utils.skip_unless_alphafold_installed()
def test_compare_evoformer_bf16(self):
self.compare_evoformer(torch.bfloat16)
@compare_utils.skip_unless_alphafold_installed()
def test_compare_evoformer_fp32(self):
self.compare_evoformer(torch.float32)
@compare_utils.skip_unless_alphafold_installed()
def test_dry_run(self):
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
batch = pickle.load(fp)
# atom37_to_atom14 doesn't like batches
batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
batch["no_recycling_iters"] = np.array([3., 3., 3., 3., ])
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
batch["aatype"] = batch["aatype"].long()
batch["template_aatype"] = batch["template_aatype"].long()
batch["extra_msa"] = batch["extra_msa"].long()
batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14"
].long()
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch)
)
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch)
with torch.no_grad():
model = compare_utils.get_global_pretrained_openfold()
model.globals.use_deepspeed_evo_attention = True
out_repro = model(batch)
if __name__ == "__main__":
unittest.main()
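Note on running the new tests: they require a CUDA device, DeepSpeed with the deepspeed4science ops built, and (for the Evoformer comparison tests) the pretrained-model setup used by compare_utils. The file can be executed directly with Python via the unittest.main() entry point, or picked up by the repository's existing test runner; the exact invocation depends on the test file's name, which is not shown in this view.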