Commit f0a320e0 authored by Christina Floristean

Integrated deepspeed attention kernel and added initial tests.

parent 2134cc09
@@ -10,9 +10,18 @@
     "bfloat16": {
         "enabled": true
     },
+    "optimizer": {
+        "type": "Adam",
+        "params": {
+            "lr": 1e-3,
+            "eps": 1e-5
+        }
+    },
     "zero_optimization": {
         "stage": 2,
-        "cpu_offload": true,
+        "offload_optimizer": {
+            "device": "cpu"
+        },
         "contiguous_gradients": true
     },
     "activation_checkpointing": {
...
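For context, a minimal sketch (not part of this commit) of how a DeepSpeed config like the one above is consumed; the file name ds_config.json and the toy model are assumptions:

import deepspeed
import torch.nn as nn

# Stand-in model; the real caller is OpenFold's training script.
model = nn.Linear(128, 128)

# deepspeed.initialize picks up the Adam optimizer, bf16, and ZeRO stage-2
# settings (with optimizer state offloaded to CPU) from the JSON config.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",  # assumed path to the config shown above
)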
@@ -367,6 +367,7 @@ config = mlc.ConfigDict(
         "globals": {
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
+            "use_deepspeed_evo_attention": False,
             # Use Staats & Rabe's low-memory attention algorithm. Mutually
             # exclusive with use_flash.
             "use_lma": False,
...
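The new flag defaults to False. A short sketch of turning it on at run time, assuming OpenFold's existing config entry point (the preset name is illustrative):

from openfold.config import model_config

config = model_config("model_1")  # preset name is illustrative
config.globals.use_deepspeed_evo_attention = True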
@@ -181,6 +181,7 @@ class EvoformerBlockCore(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,
@@ -260,6 +261,7 @@ class EvoformerBlockCore(nn.Module):
             mask=pair_mask,
             chunk_size=_attn_chunk_size,
             use_memory_efficient_kernel=False,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
         )
@@ -279,6 +281,7 @@ class EvoformerBlockCore(nn.Module):
             mask=pair_mask.transpose(-1, -2),
             chunk_size=_attn_chunk_size,
             use_memory_efficient_kernel=False,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
         )
@@ -365,6 +368,7 @@ class EvoformerBlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
         inplace_safe: bool = False,
@@ -392,6 +396,7 @@ class EvoformerBlock(nn.Module):
                 mask=msa_mask,
                 chunk_size=_attn_chunk_size,
                 use_memory_efficient_kernel=False,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma,
             )
         ),
@@ -403,6 +408,7 @@ class EvoformerBlock(nn.Module):
             m,
             mask=msa_mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
         ),
@@ -418,7 +424,8 @@ class EvoformerBlock(nn.Module):
             input_tensors,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
             _mask_trans=_mask_trans,
@@ -494,6 +501,7 @@ class ExtraMSABlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,
@@ -520,7 +528,8 @@ class ExtraMSABlock(nn.Module):
             mask=msa_mask,
             chunk_size=_attn_chunk_size,
             use_lma=use_lma,
-            use_memory_efficient_kernel=not use_lma,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
+            use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention),
             _checkpoint_chunks=
                 self.ckpt if torch.is_grad_enabled() else False,
         )
@@ -554,6 +563,7 @@ class ExtraMSABlock(nn.Module):
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
             _mask_trans=_mask_trans,
@@ -674,6 +684,7 @@ class EvoformerStack(nn.Module):
         m: torch.Tensor,
         z: torch.Tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool,
         use_lma: bool,
         use_flash: bool,
         msa_mask: Optional[torch.Tensor],
@@ -687,6 +698,7 @@ class EvoformerStack(nn.Module):
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             inplace_safe=inplace_safe,
@@ -726,6 +738,7 @@ class EvoformerStack(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
         _mask_trans: bool = True,
@@ -737,6 +750,7 @@ class EvoformerStack(nn.Module):
             m=input_tensors[0],
             z=input_tensors[1],
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             msa_mask=msa_mask,
@@ -768,6 +782,7 @@ class EvoformerStack(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
         inplace_safe: bool = False,
@@ -802,6 +817,7 @@ class EvoformerStack(nn.Module):
             m=m,
             z=z,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             msa_mask=msa_mask,
@@ -882,6 +898,7 @@ class ExtraMSAStack(nn.Module):
         m: torch.Tensor,
         z: torch.Tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool,
         use_lma: bool,
         msa_mask: Optional[torch.Tensor],
         pair_mask: Optional[torch.Tensor],
@@ -893,7 +910,8 @@ class ExtraMSAStack(nn.Module):
                 b,
                 msa_mask=msa_mask,
                 pair_mask=pair_mask,
                 chunk_size=chunk_size,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma,
                 inplace_safe=inplace_safe,
                 _mask_trans=_mask_trans,
@@ -930,6 +948,7 @@ class ExtraMSAStack(nn.Module):
     def _forward_offload(self,
         input_tensors: Sequence[torch.Tensor],
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         msa_mask: Optional[torch.Tensor] = None,
         pair_mask: Optional[torch.Tensor] = None,
@@ -942,6 +961,7 @@ class ExtraMSAStack(nn.Module):
             m=input_tensors[0],
             z=input_tensors[1],
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
@@ -968,6 +988,7 @@ class ExtraMSAStack(nn.Module):
         msa_mask: Optional[torch.Tensor],
         pair_mask: Optional[torch.Tensor],
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,
@@ -992,6 +1013,7 @@ class ExtraMSAStack(nn.Module):
             m=m,
             z=z,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
...
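The flag is threaded unchanged from the stacks down into every block. A sketch of the resulting call, assuming an already-constructed EvoformerStack instance named evoformer and appropriately shaped inputs:

# Hypothetical call; argument names follow the signatures in the hunks above.
m, z, s = evoformer(
    m, z,
    msa_mask=msa_mask,
    pair_mask=pair_mask,
    chunk_size=None,
    use_deepspeed_evo_attention=True,  # routes attention through the DS4Sci kernel
    use_lma=False,
    use_flash=False,
)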
@@ -355,6 +355,7 @@ class AlphaFold(nn.Module):
             input_tensors,
             msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
             chunk_size=self.globals.chunk_size,
+            use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
             use_lma=self.globals.use_lma,
             pair_mask=pair_mask.to(dtype=m.dtype),
             _mask_trans=self.config._mask_trans,
@@ -367,6 +368,7 @@ class AlphaFold(nn.Module):
             a, z,
             msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
             chunk_size=self.globals.chunk_size,
+            use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
             use_lma=self.globals.use_lma,
             pair_mask=pair_mask.to(dtype=m.dtype),
             inplace_safe=inplace_safe,
@@ -385,6 +387,7 @@ class AlphaFold(nn.Module):
             msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
             pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
             chunk_size=self.globals.chunk_size,
+            use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
             use_lma=self.globals.use_lma,
             _mask_trans=self.config._mask_trans,
         )
@@ -397,6 +400,7 @@ class AlphaFold(nn.Module):
             msa_mask=msa_mask.to(dtype=m.dtype),
             pair_mask=pair_mask.to(dtype=z.dtype),
             chunk_size=self.globals.chunk_size,
+            use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
             use_lma=self.globals.use_lma,
             use_flash=self.globals.use_flash,
             inplace_safe=inplace_safe,
...
@@ -91,7 +91,8 @@ class MSAAttention(nn.Module):
         m: torch.Tensor,
         biases: Optional[List[torch.Tensor]],
         chunk_size: int,
         use_memory_efficient_kernel: bool,
+        use_deepspeed_evo_attention: bool,
         use_lma: bool,
         use_flash: bool,
         flash_mask: Optional[torch.Tensor],
@@ -103,6 +104,7 @@ class MSAAttention(nn.Module):
             kv_x=m,
             biases=biases,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             flash_mask=flash_mask,
@@ -221,6 +223,7 @@ class MSAAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
         inplace_safe: bool = False,
@@ -267,7 +270,8 @@ class MSAAttention(nn.Module):
             m,
             biases,
             chunk_size,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             flash_mask=mask,
@@ -279,6 +283,7 @@ class MSAAttention(nn.Module):
             kv_x=m,
             biases=biases,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             flash_mask=mask,
@@ -356,6 +361,7 @@ class MSAColumnAttention(nn.Module):
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
     ) -> torch.Tensor:
@@ -378,7 +384,8 @@ class MSAColumnAttention(nn.Module):
         m = self._msa_att(
             m,
             mask=mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
         )
...
@@ -12,20 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partial
 import importlib
 import math
 from typing import Optional, Callable, List, Tuple, Sequence

 import numpy as np

 deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
-if(deepspeed_is_installed):
+if deepspeed_is_installed:
     import deepspeed
+    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
-if(fa_is_installed):
-    from flash_attn.bert_padding import unpad_input, pad_input
-    from flash_attn.flash_attention import FlashAttention
+if fa_is_installed:
+    from flash_attn.bert_padding import unpad_input
     from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func

 import torch
@@ -33,7 +32,6 @@ import torch.nn as nn
 from scipy.stats import truncnorm

 from openfold.utils.checkpointing import get_checkpoint_fn
-from openfold.utils.chunk_utils import _chunk_slice
 from openfold.utils.kernel.attention_core import attention_core
 from openfold.utils.precision_utils import is_fp16_enabled
 from openfold.utils.tensor_utils import (
@@ -42,8 +40,8 @@ from openfold.utils.tensor_utils import (
 )

-DEFAULT_LMA_Q_CHUNK_SIZE=1024
-DEFAULT_LMA_KV_CHUNK_SIZE=4096
+DEFAULT_LMA_Q_CHUNK_SIZE = 1024
+DEFAULT_LMA_KV_CHUNK_SIZE = 4096


 def _prod(nums):
@@ -196,9 +194,9 @@ class LayerNorm(nn.Module):
         d = x.dtype
         deepspeed_is_initialized = (
             deepspeed_is_installed and
-            deepspeed.utils.is_initialized()
+            deepspeed.comm.comm.is_initialized()
         )
-        if(d is torch.bfloat16 and not deepspeed_is_initialized):
+        if d is torch.bfloat16 and not deepspeed_is_initialized:
             with torch.cuda.amp.autocast(enabled=False):
                 out = nn.functional.layer_norm(
                     x,
@@ -228,9 +226,9 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
     d = t.dtype
     deepspeed_is_initialized = (
         deepspeed_is_installed and
-        deepspeed.utils.is_initialized()
+        deepspeed.comm.comm.is_initialized()
     )
-    if(d is torch.bfloat16 and not deepspeed_is_initialized):
+    if d is torch.bfloat16 and not deepspeed_is_initialized:
         with torch.cuda.amp.autocast(enabled=False):
             s = torch.nn.functional.softmax(t, dim=dim)
     else:
@@ -262,7 +260,7 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias
 def _attention_chunked_trainable(
     query, key, value, biases, chunk_size, chunk_dim, checkpoint,
 ):
-    if(checkpoint and len(biases) > 2):
+    if checkpoint and len(biases) > 2:
         raise ValueError(
             "Checkpointed version permits only two bias terms"
         )
@@ -290,7 +288,7 @@ def _attention_chunked_trainable(
         )
         return b[tuple(idx)]

-    if(checkpoint):
+    if checkpoint:
         bias_1_chunk, bias_2_chunk = [
             _slice_bias(b) if b is not None else None
             for b in (biases + [None, None])[:2]
@@ -404,7 +402,7 @@ class Attention(nn.Module):
         o: torch.Tensor,
         q_x: torch.Tensor
     ) -> torch.Tensor:
-        if(self.linear_g is not None):
+        if self.linear_g is not None:
             g = self.sigmoid(self.linear_g(q_x))

             # [*, Q, H, C_hidden]
@@ -425,11 +423,12 @@ class Attention(nn.Module):
         kv_x: torch.Tensor,
         biases: Optional[List[torch.Tensor]] = None,
         use_memory_efficient_kernel: bool = False,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
         lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
         use_flash: bool = False,
-        flash_mask: Optional[torch.Tensor] = None,
+        flash_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         """
         Args:
@@ -444,6 +443,10 @@ class Attention(nn.Module):
                 This should be the default choice for most. If none of the
                 "use_<...>" flags are True, a stock PyTorch implementation
                 is used instead
+            use_deepspeed_evo_attention:
+                Whether to use DeepSpeed memory-efficient attention kernel.
+                If none of the "use_<...>" flags are True, a stock PyTorch
+                implementation is used instead
             use_lma:
                 Whether to use low-memory attention (Staats & Rabe 2021). If
                 none of the "use_<...>" flags are True, a stock PyTorch
@@ -455,25 +458,25 @@ class Attention(nn.Module):
         Returns
             [*, Q, C_q] attention update
         """
-        if(use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None)):
+        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
             raise ValueError(
                 "If use_lma is specified, lma_q_chunk_size and "
                 "lma_kv_chunk_size must be provided"
             )

-        if(use_flash and biases is not None):
+        if use_flash and biases is not None:
             raise ValueError(
                 "use_flash is incompatible with the bias option. For masking, "
                 "use flash_mask instead"
             )

-        attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
-        if(sum(attn_options) > 1):
+        attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, use_flash]
+        if sum(attn_options) > 1:
             raise ValueError(
                 "Choose at most one alternative attention algorithm"
             )

-        if(biases is None):
+        if biases is None:
             biases = []

         # [*, H, Q/K, C_hidden]
@@ -483,22 +486,47 @@ class Attention(nn.Module):
         if is_fp16_enabled():
             use_memory_efficient_kernel = False

-        if(use_memory_efficient_kernel):
-            if(len(biases) > 2):
+        if use_memory_efficient_kernel:
+            if len(biases) > 2:
                 raise ValueError(
                     "If use_memory_efficient_kernel is True, you may only "
                     "provide up to two bias terms"
                 )
             o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
             o = o.transpose(-2, -3)
-        elif(use_lma):
+        elif use_deepspeed_evo_attention:
+            q = q.transpose(-2, -3)
+            k = k.transpose(-2, -3)
+            v = v.transpose(-2, -3)
+
+            add_batch_dim = len(q.shape) < 5
+            if add_batch_dim:
+                q = q.unsqueeze(0)
+                k = k.unsqueeze(0)
+                v = v.unsqueeze(0)
+                biases = [b.unsqueeze(0) for b in biases]
+
+            orig_dtype = q.dtype
+            if orig_dtype not in [torch.bfloat16, torch.float16]:
+                o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
+                                              k.to(dtype=torch.bfloat16),
+                                              v.to(dtype=torch.bfloat16),
+                                              [b.to(dtype=torch.bfloat16) for b in biases])
+                o = o.to(dtype=orig_dtype)
+            else:
+                o = DS4Sci_EvoformerAttention(q, k, v, biases)
+
+            if add_batch_dim:
+                o = o.squeeze(0)
+        elif use_lma:
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
                 for b in biases
             ]
             o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
             o = o.transpose(-2, -3)
-        elif(use_flash):
+        elif use_flash:
             o = _flash_attn(q, k, v, flash_mask)
         else:
             o = _attention(q, k, v, biases)
@@ -556,7 +584,7 @@ class GlobalAttention(nn.Module):
         v = self.linear_v(m)

         bias = (self.inf * (mask - 1))[..., :, None, :]

-        if(not use_lma):
+        if not use_lma:
             # [*, N_res, H, N_seq]
             a = torch.matmul(
                 q,
@@ -662,7 +690,7 @@ def _lma(
 @torch.jit.ignore
 def _flash_attn(q, k, v, kv_mask):
-    if(not fa_is_installed):
+    if not fa_is_installed:
         raise ValueError(
             "_flash_attn requires that FlashAttention be installed"
         )
@@ -714,8 +742,8 @@ def _flash_attn(q, k, v, kv_mask):
         kv_cu_seqlens,
         q_max_s,
         kv_max_s,
-        dropout_p = 0.,
-        softmax_scale = 1., # q has been scaled already
+        dropout_p=0.,
+        softmax_scale=1.,  # q has been scaled already
     )

     # [*, B, N, H, C]
...
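For reference, a standalone sketch of the kernel's calling convention as the new branch above uses it: Q/K/V in [batch, N_seq, N_res, heads, c_hidden] layout, up to two biases broadcastable to [batch, N_seq, heads, N_res, N_res], and bfloat16/float16 inputs only. The dimension values here are illustrative:

import torch
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

B, N_seq, N_res, H, C = 1, 8, 64, 4, 32
q = torch.rand(B, N_seq, N_res, H, C, dtype=torch.bfloat16, device="cuda")
k = torch.rand_like(q)
v = torch.rand_like(q)

# A mask bias broadcast over heads/queries plus a pair bias, mirroring the
# bias shapes used in the new unit test below.
mask_bias = torch.rand(B, N_seq, 1, 1, N_res, dtype=torch.bfloat16, device="cuda")
pair_bias = torch.rand(B, 1, H, N_res, N_res, dtype=torch.bfloat16, device="cuda")

# Output keeps the [B, N_seq, N_res, H, C] layout of the inputs.
o = DS4Sci_EvoformerAttention(q, k, v, [mask_bias, pair_bias])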
@@ -63,6 +63,7 @@ class TriangleAttention(nn.Module):
         biases: List[torch.Tensor],
         chunk_size: int,
         use_memory_efficient_kernel: bool = False,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
     ) -> torch.Tensor:
@@ -77,6 +78,7 @@ class TriangleAttention(nn.Module):
             partial(
                 self.mha,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma
             ),
             mha_inputs,
@@ -90,6 +92,7 @@ class TriangleAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
     ) -> torch.Tensor:
@@ -130,6 +133,7 @@ class TriangleAttention(nn.Module):
             biases,
             chunk_size,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
         )
@@ -139,6 +143,7 @@ class TriangleAttention(nn.Module):
             kv_x=x,
             biases=biases,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma
         )
...
@@ -3,7 +3,7 @@ import ml_collections as mlc
 consts = mlc.ConfigDict(
     {
         "batch_size": 2,
-        "n_res": 11,
+        "n_res": 20,
         "n_seq": 13,
         "n_templ": 3,
         "n_extra": 17,
...
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import unittest

import numpy as np
import torch

from openfold.data import data_transforms
from openfold.model.primitives import Attention
from openfold.utils.tensor_utils import tensor_tree_map
from tests.config import consts
import tests.compare_utils as compare_utils


class TestDeepSpeedKernel(unittest.TestCase):
    def test_ds_kernel_vs_attention(self):
        batch_size = consts.batch_size
        c_hidden = 32
        n = 2 ** 12
        n_seq = 12
        no_heads = 4

        q = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
        kv = torch.rand(batch_size, n_seq, n, c_hidden).cuda()

        bias = [
            torch.rand(batch_size, n_seq, 1, 1, n),
            torch.rand(batch_size, 1, no_heads, n, n),
        ]
        bias = [b.cuda() for b in bias]

        a = Attention(
            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
        ).cuda()

        with torch.no_grad():
            l = a(q, kv, biases=bias, use_deepspeed_evo_attention=True)
            real = a(q, kv, biases=bias)

        self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)

    def compare_evoformer(self, dtype):
        n_res = consts.n_res
        n_seq = consts.n_seq

        activations = {
            "msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
            "pair": torch.rand(n_res, n_res, consts.c_z, device='cuda', dtype=dtype),
        }

        masks = {
            "msa": torch.randint(0, 2, (n_seq, n_res), device='cuda', dtype=dtype),
            "pair": torch.randint(0, 2, (n_res, n_res), device='cuda', dtype=dtype),
        }

        with torch.cuda.amp.autocast(dtype=dtype):
            model = compare_utils.get_global_pretrained_openfold()
            out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
                activations["msa"],
                activations["pair"],
                masks["msa"],
                masks["pair"],
                use_deepspeed_evo_attention=False,
                chunk_size=4,
                _mask_trans=False,
                inplace_safe=False,
            )
            out_repro_msa = out_repro_msa.cpu()
            out_repro_pair = out_repro_pair.cpu()

            out_repro_msa_ds, out_repro_pair_ds = model.evoformer.blocks[0](
                activations["msa"],
                activations["pair"],
                masks["msa"],
                masks["pair"],
                use_deepspeed_evo_attention=True,
                chunk_size=4,
                _mask_trans=False,
                inplace_safe=False,
            )
            out_repro_msa_ds = out_repro_msa_ds.cpu()
            out_repro_pair_ds = out_repro_pair_ds.cpu()

            self.assertTrue(torch.allclose(torch.abs(out_repro_msa), torch.abs(out_repro_msa_ds), atol=consts.eps))
            self.assertTrue(torch.allclose(torch.abs(out_repro_pair), torch.abs(out_repro_pair_ds), atol=consts.eps))

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare_evoformer_bf16(self):
        self.compare_evoformer(torch.bfloat16)

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare_evoformer_fp32(self):
        self.compare_evoformer(torch.float32)

    @compare_utils.skip_unless_alphafold_installed()
    def test_dry_run(self):
        with open("tests/test_data/sample_feats.pickle", "rb") as fp:
            batch = pickle.load(fp)

        # atom37_to_atom14 doesn't like batches
        batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
        batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]

        batch["no_recycling_iters"] = np.array([3., 3., 3., 3.])
        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
        batch["aatype"] = batch["aatype"].long()
        batch["template_aatype"] = batch["template_aatype"].long()
        batch["extra_msa"] = batch["extra_msa"].long()
        batch["residx_atom37_to_atom14"] = batch[
            "residx_atom37_to_atom14"
        ].long()
        batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
        batch.update(
            data_transforms.atom37_to_torsion_angles("template_")(batch)
        )

        # Move the recycling dimension to the end
        move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
        batch = tensor_tree_map(move_dim, batch)

        with torch.no_grad():
            model = compare_utils.get_global_pretrained_openfold()
            model.globals.use_deepspeed_evo_attention = True
            out_repro = model(batch)


if __name__ == "__main__":
    unittest.main()