Unverified Commit a13c0ce5 authored by Christina Floristean's avatar Christina Floristean Committed by GitHub
Browse files

Merge pull request #378 from aqlaboratory/deepspeed-evo-attention

Deepspeed evoformer attention
parents 2dc080ce 40d76358
...@@ -39,13 +39,14 @@ kernels support in-place attention during inference and training. They use ...@@ -39,13 +39,14 @@ kernels support in-place attention during inference and training. They use
implementations, respectively. implementations, respectively.
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments. - **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.
- **FlashAttention** support greatly speeds up MSA attention. - **FlashAttention** support greatly speeds up MSA attention.
- **DeepSpeed DS4Sci_EvoformerAttention kernel** is a memory-efficient attention kernel developed as part of a collaboration between OpenFold and the DeepSpeed4Science initiative. The kernel provides substantial speedups for training and inference, and significantly reduces the model's peak device memory requirement by 13X. The model is 15% faster during the initial training and finetuning stages, and up to 4x faster during inference. To use this feature, simply set the `use_deepspeed_evo_attention` option in `openfold/config.py`.
## Installation (Linux) ## Installation (Linux)
All Python dependencies are specified in `environment.yml`. For producing sequence All Python dependencies are specified in `environment.yml`. For producing sequence
alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite), alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite),
and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)} and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)}
installed on on your system. You'll need `git-lfs` to download OpenFold parameters. installed on your system. You'll need `git-lfs` to download OpenFold parameters.
Finally, some download scripts require `aria2c` and `aws`. Finally, some download scripts require `aria2c` and `aws`.
This package is currently supported for CUDA 11 and Pytorch 1.12 This package is currently supported for CUDA 11 and Pytorch 1.12
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
}, },
"zero_optimization": { "zero_optimization": {
"stage": 2, "stage": 2,
"cpu_offload": true, "offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true "contiguous_gradients": true
}, },
"activation_checkpointing": { "activation_checkpointing": {
...@@ -20,5 +22,6 @@ ...@@ -20,5 +22,6 @@
"cpu_checkpointing": false, "cpu_checkpointing": false,
"profile": false "profile": false
}, },
"gradient_clipping": 0.1 "gradient_clipping": 0.1,
"zero_force_ds_cpu_optimizer": false
} }
...@@ -30,7 +30,7 @@ dependencies: ...@@ -30,7 +30,7 @@ dependencies:
- bioconda::kalign2==2.04 - bioconda::kalign2==2.04
- pytorch::pytorch=1.12.* - pytorch::pytorch=1.12.*
- pip: - pip:
- deepspeed==0.5.10 - deepspeed==0.12.4
- dm-tree==0.1.6 - dm-tree==0.1.6
- git+https://github.com/NVIDIA/dllogger.git - git+https://github.com/NVIDIA/dllogger.git
- git+https://github.com/Dao-AILab/flash-attention.git@5b838a8 - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
...@@ -28,19 +28,28 @@ def enforce_config_constraints(config): ...@@ -28,19 +28,28 @@ def enforce_config_constraints(config):
( (
"globals.use_lma", "globals.use_lma",
"globals.use_flash", "globals.use_flash",
"globals.use_deepspeed_evo_attention"
), ),
] ]
for s1, s2 in mutually_exclusive_bools: for options in mutually_exclusive_bools:
s1_setting = string_to_setting(s1) option_settings = [string_to_setting(o) for o in options]
s2_setting = string_to_setting(s2) if sum(option_settings) > 1:
if(s1_setting and s2_setting): raise ValueError(f"Only one of {', '.join(options)} may be set at a time")
raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
fa_is_installed = importlib.util.find_spec("flash_attn") is not None fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(config.globals.use_flash and not fa_is_installed): if config.globals.use_flash and not fa_is_installed:
raise ValueError("use_flash requires that FlashAttention is installed") raise ValueError("use_flash requires that FlashAttention is installed")
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
"deepspeed.ops.deepspeed4science") is not None
if config.globals.use_deepspeed_evo_attention and not ds4s_is_installed:
raise ValueError(
"use_deepspeed_evo_attention requires that DeepSpeed be installed "
"and that the deepspeed.ops.deepspeed4science package exists"
)
if( if(
config.globals.offload_inference and config.globals.offload_inference and
not config.model.template.average_templates not config.model.template.average_templates
...@@ -193,7 +202,8 @@ def model_config( ...@@ -193,7 +202,8 @@ def model_config(
if long_sequence_inference: if long_sequence_inference:
assert(not train) assert(not train)
c.globals.offload_inference = True c.globals.offload_inference = True
c.globals.use_lma = True # Default to DeepSpeed memory-efficient attention kernel unless use_lma is explicitly set
c.globals.use_deepspeed_evo_attention = True if not c.globals.use_lma else False
c.globals.use_flash = False c.globals.use_flash = False
c.model.template.offload_inference = True c.model.template.offload_inference = True
c.model.template.template_pair_stack.tune_chunk_size = False c.model.template.template_pair_stack.tune_chunk_size = False
...@@ -419,11 +429,15 @@ config = mlc.ConfigDict( ...@@ -419,11 +429,15 @@ config = mlc.ConfigDict(
"seqemb_mode_enabled": False, # Global flag for enabling seq emb mode "seqemb_mode_enabled": False, # Global flag for enabling seq emb mode
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size, "chunk_size": chunk_size,
# Use DeepSpeed memory-efficient attention kernel. Mutually
# exclusive with use_lma and use_flash.
"use_deepspeed_evo_attention": False,
# Use Staats & Rabe's low-memory attention algorithm. Mutually # Use Staats & Rabe's low-memory attention algorithm. Mutually
# exclusive with use_flash. # exclusive with use_deepspeed_evo_attention and use_flash.
"use_lma": False, "use_lma": False,
# Use FlashAttention in selected modules. Mutually exclusive with # Use FlashAttention in selected modules. Mutually exclusive with
# use_lma. Doesn't work that well on long sequences (>1000 residues). # use_deepspeed_evo_attention and use_lma. Doesn't work that well
# on long sequences (>1000 residues).
"use_flash": False, "use_flash": False,
"offload_inference": False, "offload_inference": False,
"c_z": c_z, "c_z": c_z,
......
...@@ -87,7 +87,6 @@ class MSATransition(nn.Module): ...@@ -87,7 +87,6 @@ class MSATransition(nn.Module):
no_batch_dims=len(m.shape[:-2]), no_batch_dims=len(m.shape[:-2]),
) )
def forward( def forward(
self, self,
m: torch.Tensor, m: torch.Tensor,
...@@ -181,6 +180,7 @@ class EvoformerBlockCore(nn.Module): ...@@ -181,6 +180,7 @@ class EvoformerBlockCore(nn.Module):
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
...@@ -260,6 +260,7 @@ class EvoformerBlockCore(nn.Module): ...@@ -260,6 +260,7 @@ class EvoformerBlockCore(nn.Module):
mask=pair_mask, mask=pair_mask,
chunk_size=_attn_chunk_size, chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False, use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
) )
...@@ -279,6 +280,7 @@ class EvoformerBlockCore(nn.Module): ...@@ -279,6 +280,7 @@ class EvoformerBlockCore(nn.Module):
mask=pair_mask.transpose(-1, -2), mask=pair_mask.transpose(-1, -2),
chunk_size=_attn_chunk_size, chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False, use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
) )
...@@ -339,7 +341,7 @@ class EvoformerBlock(nn.Module): ...@@ -339,7 +341,7 @@ class EvoformerBlock(nn.Module):
# Specifically, seqemb mode does not use column attention # Specifically, seqemb mode does not use column attention
self.no_column_attention = no_column_attention self.no_column_attention = no_column_attention
if self.no_column_attention == False: if not self.no_column_attention:
self.msa_att_col = MSAColumnAttention( self.msa_att_col = MSAColumnAttention(
c_m, c_m,
c_hidden_msa_att, c_hidden_msa_att,
...@@ -369,6 +371,7 @@ class EvoformerBlock(nn.Module): ...@@ -369,6 +371,7 @@ class EvoformerBlock(nn.Module):
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
use_flash: bool = False, use_flash: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
...@@ -396,6 +399,7 @@ class EvoformerBlock(nn.Module): ...@@ -396,6 +399,7 @@ class EvoformerBlock(nn.Module):
mask=msa_mask, mask=msa_mask,
chunk_size=_attn_chunk_size, chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False, use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
) )
), ),
...@@ -403,12 +407,13 @@ class EvoformerBlock(nn.Module): ...@@ -403,12 +407,13 @@ class EvoformerBlock(nn.Module):
) )
# Specifically, column attention is not used in seqemb mode. # Specifically, column attention is not used in seqemb mode.
if self.no_column_attention == False: if not self.no_column_attention:
m = add(m, m = add(m,
self.msa_att_col( self.msa_att_col(
m, m,
mask=msa_mask, mask=msa_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash, use_flash=use_flash,
), ),
...@@ -424,7 +429,8 @@ class EvoformerBlock(nn.Module): ...@@ -424,7 +429,8 @@ class EvoformerBlock(nn.Module):
input_tensors, input_tensors,
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
...@@ -500,6 +506,7 @@ class ExtraMSABlock(nn.Module): ...@@ -500,6 +506,7 @@ class ExtraMSABlock(nn.Module):
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
...@@ -526,7 +533,8 @@ class ExtraMSABlock(nn.Module): ...@@ -526,7 +533,8 @@ class ExtraMSABlock(nn.Module):
mask=msa_mask, mask=msa_mask,
chunk_size=_attn_chunk_size, chunk_size=_attn_chunk_size,
use_lma=use_lma, use_lma=use_lma,
use_memory_efficient_kernel=not use_lma, use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention),
_checkpoint_chunks= _checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False, self.ckpt if torch.is_grad_enabled() else False,
) )
...@@ -560,6 +568,7 @@ class ExtraMSABlock(nn.Module): ...@@ -560,6 +568,7 @@ class ExtraMSABlock(nn.Module):
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
...@@ -685,6 +694,7 @@ class EvoformerStack(nn.Module): ...@@ -685,6 +694,7 @@ class EvoformerStack(nn.Module):
m: torch.Tensor, m: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool,
use_lma: bool, use_lma: bool,
use_flash: bool, use_flash: bool,
msa_mask: Optional[torch.Tensor], msa_mask: Optional[torch.Tensor],
...@@ -698,6 +708,7 @@ class EvoformerStack(nn.Module): ...@@ -698,6 +708,7 @@ class EvoformerStack(nn.Module):
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash, use_flash=use_flash,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
...@@ -737,6 +748,7 @@ class EvoformerStack(nn.Module): ...@@ -737,6 +748,7 @@ class EvoformerStack(nn.Module):
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
use_flash: bool = False, use_flash: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
...@@ -748,6 +760,7 @@ class EvoformerStack(nn.Module): ...@@ -748,6 +760,7 @@ class EvoformerStack(nn.Module):
m=input_tensors[0], m=input_tensors[0],
z=input_tensors[1], z=input_tensors[1],
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash, use_flash=use_flash,
msa_mask=msa_mask, msa_mask=msa_mask,
...@@ -779,6 +792,7 @@ class EvoformerStack(nn.Module): ...@@ -779,6 +792,7 @@ class EvoformerStack(nn.Module):
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
use_flash: bool = False, use_flash: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
...@@ -797,10 +811,15 @@ class EvoformerStack(nn.Module): ...@@ -797,10 +811,15 @@ class EvoformerStack(nn.Module):
chunk_size: chunk_size:
Inference-time subbatch size. Acts as a minimum if Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference use_deepspeed_evo_attention:
Whether to use DeepSpeed memory efficient kernel.
Mutually exclusive with use_lma and use_flash.
use_lma:
Whether to use low-memory attention during inference.
Mutually exclusive with use_flash and use_deepspeed_evo_attention.
use_flash: use_flash:
Whether to use FlashAttention where possible. Mutually Whether to use FlashAttention where possible. Mutually
exclusive with use_lma. exclusive with use_lma and use_deepspeed_evo_attention.
Returns: Returns:
m: m:
[*, N_seq, N_res, C_m] MSA embedding [*, N_seq, N_res, C_m] MSA embedding
...@@ -813,6 +832,7 @@ class EvoformerStack(nn.Module): ...@@ -813,6 +832,7 @@ class EvoformerStack(nn.Module):
m=m, m=m,
z=z, z=z,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash, use_flash=use_flash,
msa_mask=msa_mask, msa_mask=msa_mask,
...@@ -893,6 +913,7 @@ class ExtraMSAStack(nn.Module): ...@@ -893,6 +913,7 @@ class ExtraMSAStack(nn.Module):
m: torch.Tensor, m: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool,
use_lma: bool, use_lma: bool,
msa_mask: Optional[torch.Tensor], msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor], pair_mask: Optional[torch.Tensor],
...@@ -904,7 +925,8 @@ class ExtraMSAStack(nn.Module): ...@@ -904,7 +925,8 @@ class ExtraMSAStack(nn.Module):
b, b,
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
...@@ -941,6 +963,7 @@ class ExtraMSAStack(nn.Module): ...@@ -941,6 +963,7 @@ class ExtraMSAStack(nn.Module):
def _forward_offload(self, def _forward_offload(self,
input_tensors: Sequence[torch.Tensor], input_tensors: Sequence[torch.Tensor],
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None, msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None, pair_mask: Optional[torch.Tensor] = None,
...@@ -953,6 +976,7 @@ class ExtraMSAStack(nn.Module): ...@@ -953,6 +976,7 @@ class ExtraMSAStack(nn.Module):
m=input_tensors[0], m=input_tensors[0],
z=input_tensors[1], z=input_tensors[1],
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
...@@ -979,6 +1003,7 @@ class ExtraMSAStack(nn.Module): ...@@ -979,6 +1003,7 @@ class ExtraMSAStack(nn.Module):
msa_mask: Optional[torch.Tensor], msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor], pair_mask: Optional[torch.Tensor],
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
...@@ -990,6 +1015,7 @@ class ExtraMSAStack(nn.Module): ...@@ -990,6 +1015,7 @@ class ExtraMSAStack(nn.Module):
z: z:
[*, N_res, N_res, C_z] pair embedding [*, N_res, N_res, C_z] pair embedding
chunk_size: Inference-time subbatch size for Evoformer modules chunk_size: Inference-time subbatch size for Evoformer modules
use_deepspeed_evo_attention: Whether to use DeepSpeed memory-efficient kernel
use_lma: Whether to use low-memory attention during inference use_lma: Whether to use low-memory attention during inference
msa_mask: msa_mask:
Optional [*, N_extra, N_res] MSA mask Optional [*, N_extra, N_res] MSA mask
...@@ -1003,6 +1029,7 @@ class ExtraMSAStack(nn.Module): ...@@ -1003,6 +1029,7 @@ class ExtraMSAStack(nn.Module):
m=m, m=m,
z=z, z=z,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
......
...@@ -178,6 +178,7 @@ class AlphaFold(nn.Module): ...@@ -178,6 +178,7 @@ class AlphaFold(nn.Module):
t_pair, t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype), pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma, use_lma=self.globals.use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
...@@ -374,6 +375,7 @@ class AlphaFold(nn.Module): ...@@ -374,6 +375,7 @@ class AlphaFold(nn.Module):
input_tensors, input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype), msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma, use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype), pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
...@@ -386,6 +388,7 @@ class AlphaFold(nn.Module): ...@@ -386,6 +388,7 @@ class AlphaFold(nn.Module):
a, z, a, z,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype), msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma, use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype), pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
...@@ -404,6 +407,7 @@ class AlphaFold(nn.Module): ...@@ -404,6 +407,7 @@ class AlphaFold(nn.Module):
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype), msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype), pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma, use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
...@@ -416,6 +420,7 @@ class AlphaFold(nn.Module): ...@@ -416,6 +420,7 @@ class AlphaFold(nn.Module):
msa_mask=msa_mask.to(dtype=m.dtype), msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype), pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma, use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash, use_flash=self.globals.use_flash,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
......
...@@ -91,7 +91,8 @@ class MSAAttention(nn.Module): ...@@ -91,7 +91,8 @@ class MSAAttention(nn.Module):
m: torch.Tensor, m: torch.Tensor,
biases: Optional[List[torch.Tensor]], biases: Optional[List[torch.Tensor]],
chunk_size: int, chunk_size: int,
use_memory_efficient_kernel: bool, use_memory_efficient_kernel: bool,
use_deepspeed_evo_attention: bool,
use_lma: bool, use_lma: bool,
use_flash: bool, use_flash: bool,
flash_mask: Optional[torch.Tensor], flash_mask: Optional[torch.Tensor],
...@@ -103,6 +104,7 @@ class MSAAttention(nn.Module): ...@@ -103,6 +104,7 @@ class MSAAttention(nn.Module):
kv_x=m, kv_x=m,
biases=biases, biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel, use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash, use_flash=use_flash,
flash_mask=flash_mask, flash_mask=flash_mask,
...@@ -221,6 +223,7 @@ class MSAAttention(nn.Module): ...@@ -221,6 +223,7 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False, use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
use_flash: bool = False, use_flash: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
...@@ -267,7 +270,8 @@ class MSAAttention(nn.Module): ...@@ -267,7 +270,8 @@ class MSAAttention(nn.Module):
m, m,
biases, biases,
chunk_size, chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel, use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash, use_flash=use_flash,
flash_mask=mask, flash_mask=mask,
...@@ -279,6 +283,7 @@ class MSAAttention(nn.Module): ...@@ -279,6 +283,7 @@ class MSAAttention(nn.Module):
kv_x=m, kv_x=m,
biases=biases, biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel, use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash, use_flash=use_flash,
flash_mask=mask, flash_mask=mask,
...@@ -356,6 +361,7 @@ class MSAColumnAttention(nn.Module): ...@@ -356,6 +361,7 @@ class MSAColumnAttention(nn.Module):
m: torch.Tensor, m: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
use_flash: bool = False, use_flash: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -378,7 +384,8 @@ class MSAColumnAttention(nn.Module): ...@@ -378,7 +384,8 @@ class MSAColumnAttention(nn.Module):
m = self._msa_att( m = self._msa_att(
m, m,
mask=mask, mask=mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
use_flash=use_flash, use_flash=use_flash,
) )
......
...@@ -12,20 +12,22 @@ ...@@ -12,20 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial
import importlib import importlib
import math import math
from typing import Optional, Callable, List, Tuple, Sequence from typing import Optional, Callable, List, Tuple
import numpy as np import numpy as np
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if(deepspeed_is_installed): ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
if deepspeed_is_installed:
import deepspeed import deepspeed
if ds4s_is_installed:
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
fa_is_installed = importlib.util.find_spec("flash_attn") is not None fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(fa_is_installed): if fa_is_installed:
from flash_attn.bert_padding import unpad_input, pad_input from flash_attn.bert_padding import unpad_input
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
import torch import torch
...@@ -33,7 +35,6 @@ import torch.nn as nn ...@@ -33,7 +35,6 @@ import torch.nn as nn
from scipy.stats import truncnorm from scipy.stats import truncnorm
from openfold.utils.checkpointing import get_checkpoint_fn from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import _chunk_slice
from openfold.utils.kernel.attention_core import attention_core from openfold.utils.kernel.attention_core import attention_core
from openfold.utils.precision_utils import is_fp16_enabled from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
...@@ -42,8 +43,8 @@ from openfold.utils.tensor_utils import ( ...@@ -42,8 +43,8 @@ from openfold.utils.tensor_utils import (
) )
DEFAULT_LMA_Q_CHUNK_SIZE=1024 DEFAULT_LMA_Q_CHUNK_SIZE = 1024
DEFAULT_LMA_KV_CHUNK_SIZE=4096 DEFAULT_LMA_KV_CHUNK_SIZE = 4096
def _prod(nums): def _prod(nums):
...@@ -196,9 +197,9 @@ class LayerNorm(nn.Module): ...@@ -196,9 +197,9 @@ class LayerNorm(nn.Module):
d = x.dtype d = x.dtype
deepspeed_is_initialized = ( deepspeed_is_initialized = (
deepspeed_is_installed and deepspeed_is_installed and
deepspeed.utils.is_initialized() deepspeed.comm.comm.is_initialized()
) )
if(d is torch.bfloat16 and not deepspeed_is_initialized): if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm( out = nn.functional.layer_norm(
x, x,
...@@ -228,9 +229,9 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: ...@@ -228,9 +229,9 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
d = t.dtype d = t.dtype
deepspeed_is_initialized = ( deepspeed_is_initialized = (
deepspeed_is_installed and deepspeed_is_installed and
deepspeed.utils.is_initialized() deepspeed.comm.comm.is_initialized()
) )
if(d is torch.bfloat16 and not deepspeed_is_initialized): if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim) s = torch.nn.functional.softmax(t, dim=dim)
else: else:
...@@ -262,7 +263,7 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias ...@@ -262,7 +263,7 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias
def _attention_chunked_trainable( def _attention_chunked_trainable(
query, key, value, biases, chunk_size, chunk_dim, checkpoint, query, key, value, biases, chunk_size, chunk_dim, checkpoint,
): ):
if(checkpoint and len(biases) > 2): if checkpoint and len(biases) > 2:
raise ValueError( raise ValueError(
"Checkpointed version permits only permits two bias terms" "Checkpointed version permits only permits two bias terms"
) )
...@@ -290,7 +291,7 @@ def _attention_chunked_trainable( ...@@ -290,7 +291,7 @@ def _attention_chunked_trainable(
) )
return b[tuple(idx)] return b[tuple(idx)]
if(checkpoint): if checkpoint:
bias_1_chunk, bias_2_chunk = [ bias_1_chunk, bias_2_chunk = [
_slice_bias(b) if b is not None else None _slice_bias(b) if b is not None else None
for b in (biases + [None, None])[:2] for b in (biases + [None, None])[:2]
...@@ -377,7 +378,8 @@ class Attention(nn.Module): ...@@ -377,7 +378,8 @@ class Attention(nn.Module):
def _prep_qkv(self, def _prep_qkv(self,
q_x: torch.Tensor, q_x: torch.Tensor,
kv_x: torch.Tensor kv_x: torch.Tensor,
apply_scale: bool = True
) -> Tuple[ ) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor torch.Tensor, torch.Tensor, torch.Tensor
]: ]:
...@@ -396,7 +398,8 @@ class Attention(nn.Module): ...@@ -396,7 +398,8 @@ class Attention(nn.Module):
k = k.transpose(-2, -3) k = k.transpose(-2, -3)
v = v.transpose(-2, -3) v = v.transpose(-2, -3)
q /= math.sqrt(self.c_hidden) if apply_scale:
q /= math.sqrt(self.c_hidden)
return q, k, v return q, k, v
...@@ -404,7 +407,7 @@ class Attention(nn.Module): ...@@ -404,7 +407,7 @@ class Attention(nn.Module):
o: torch.Tensor, o: torch.Tensor,
q_x: torch.Tensor q_x: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
if(self.linear_g is not None): if self.linear_g is not None:
g = self.sigmoid(self.linear_g(q_x)) g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden] # [*, Q, H, C_hidden]
...@@ -425,11 +428,12 @@ class Attention(nn.Module): ...@@ -425,11 +428,12 @@ class Attention(nn.Module):
kv_x: torch.Tensor, kv_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None, biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False, use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE, lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE, lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
use_flash: bool = False, use_flash: bool = False,
flash_mask: Optional[torch.Tensor] = None, flash_mask: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -444,6 +448,10 @@ class Attention(nn.Module): ...@@ -444,6 +448,10 @@ class Attention(nn.Module):
This should be the default choice for most. If none of the This should be the default choice for most. If none of the
"use_<...>" flags are True, a stock PyTorch implementation "use_<...>" flags are True, a stock PyTorch implementation
is used instead is used instead
use_deepspeed_evo_attention:
Whether to use DeepSpeed memory-efficient attention kernel.
If none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead
use_lma: use_lma:
Whether to use low-memory attention (Staats & Rabe 2021). If Whether to use low-memory attention (Staats & Rabe 2021). If
none of the "use_<...>" flags are True, a stock PyTorch none of the "use_<...>" flags are True, a stock PyTorch
...@@ -455,50 +463,57 @@ class Attention(nn.Module): ...@@ -455,50 +463,57 @@ class Attention(nn.Module):
Returns Returns
[*, Q, C_q] attention update [*, Q, C_q] attention update
""" """
if(use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None)): if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
raise ValueError( raise ValueError(
"If use_lma is specified, lma_q_chunk_size and " "If use_lma is specified, lma_q_chunk_size and "
"lma_kv_chunk_size must be provided" "lma_kv_chunk_size must be provided"
) )
if(use_flash and biases is not None): if use_flash and biases is not None:
raise ValueError( raise ValueError(
"use_flash is incompatible with the bias option. For masking, " "use_flash is incompatible with the bias option. For masking, "
"use flash_mask instead" "use flash_mask instead"
) )
attn_options = [use_memory_efficient_kernel, use_lma, use_flash] attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, use_flash]
if(sum(attn_options) > 1): if sum(attn_options) > 1:
raise ValueError( raise ValueError(
"Choose at most one alternative attention algorithm" "Choose at most one alternative attention algorithm"
) )
if(biases is None): if biases is None:
biases = [] biases = []
# [*, H, Q/K, C_hidden] # DeepSpeed attention kernel applies scaling internally
q, k, v = self._prep_qkv(q_x, kv_x) q, k, v = self._prep_qkv(q_x, kv_x,
apply_scale=not use_deepspeed_evo_attention)
# [*, Q, H, C_hidden]
if is_fp16_enabled(): if is_fp16_enabled():
use_memory_efficient_kernel = False use_memory_efficient_kernel = False
if(use_memory_efficient_kernel): if use_memory_efficient_kernel:
if(len(biases) > 2): if len(biases) > 2:
raise ValueError( raise ValueError(
"If use_memory_efficient_kernel is True, you may only " "If use_memory_efficient_kernel is True, you may only "
"provide up to two bias terms" "provide up to two bias terms"
) )
o = attention_core(q, k, v, *((biases + [None] * 2)[:2])) o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
o = o.transpose(-2, -3) o = o.transpose(-2, -3)
elif(use_lma): elif use_deepspeed_evo_attention:
if len(biases) > 2:
raise ValueError(
"If use_deepspeed_evo_attention is True, you may only "
"provide up to two bias terms"
)
o = _deepspeed_evo_attn(q, k, v, biases)
elif use_lma:
biases = [ biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases for b in biases
] ]
o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size) o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
o = o.transpose(-2, -3) o = o.transpose(-2, -3)
elif(use_flash): elif use_flash:
o = _flash_attn(q, k, v, flash_mask) o = _flash_attn(q, k, v, flash_mask)
else: else:
o = _attention(q, k, v, biases) o = _attention(q, k, v, biases)
...@@ -556,7 +571,7 @@ class GlobalAttention(nn.Module): ...@@ -556,7 +571,7 @@ class GlobalAttention(nn.Module):
v = self.linear_v(m) v = self.linear_v(m)
bias = (self.inf * (mask - 1))[..., :, None, :] bias = (self.inf * (mask - 1))[..., :, None, :]
if(not use_lma): if not use_lma:
# [*, N_res, H, N_seq] # [*, N_res, H, N_seq]
a = torch.matmul( a = torch.matmul(
q, q,
...@@ -598,6 +613,72 @@ class GlobalAttention(nn.Module): ...@@ -598,6 +613,72 @@ class GlobalAttention(nn.Module):
return m return m
@torch.jit.ignore
def _deepspeed_evo_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
biases: List[torch.Tensor],
):
"""""
Compute attention using the DeepSpeed DS4Sci_EvoformerAttention kernel.
Args:
q:
[*, H, Q, C_hidden] query data
k:
[*, H, K, C_hidden] key data
v:
[*, H, V, C_hidden] value data
biases:
List of biases that broadcast to [*, H, Q, K]
"""
if not ds4s_is_installed:
raise ValueError(
"_deepspeed_evo_attn requires that DeepSpeed be installed "
"and that the deepspeed.ops.deepspeed4science package exists"
)
def reshape_dims(x):
no_batch_dims = len(x.shape[:-3])
if no_batch_dims < 2:
return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
if no_batch_dims > 2:
return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
return x
# [*, Q/K, H, C_hidden]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
# Reshape tensors to match expected input shape [B, N, Q/K, H, C_hidden]
# for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed.
orig_shape = q.shape
if len(orig_shape[:-3]) != 2:
q = reshape_dims(q)
k = reshape_dims(k)
v = reshape_dims(v)
biases = [reshape_dims(b) for b in biases]
# DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
# Cast to bf16 so kernel can be used during inference
orig_dtype = q.dtype
if orig_dtype not in [torch.bfloat16, torch.float16]:
o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
k.to(dtype=torch.bfloat16),
v.to(dtype=torch.bfloat16),
[b.to(dtype=torch.bfloat16) for b in biases])
o = o.to(dtype=orig_dtype)
else:
o = DS4Sci_EvoformerAttention(q, k, v, biases)
o = o.reshape(orig_shape)
return o
def _lma( def _lma(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
...@@ -662,7 +743,7 @@ def _lma( ...@@ -662,7 +743,7 @@ def _lma(
@torch.jit.ignore @torch.jit.ignore
def _flash_attn(q, k, v, kv_mask): def _flash_attn(q, k, v, kv_mask):
if(not fa_is_installed): if not fa_is_installed:
raise ValueError( raise ValueError(
"_flash_attn requires that FlashAttention be installed" "_flash_attn requires that FlashAttention be installed"
) )
...@@ -714,8 +795,8 @@ def _flash_attn(q, k, v, kv_mask): ...@@ -714,8 +795,8 @@ def _flash_attn(q, k, v, kv_mask):
kv_cu_seqlens, kv_cu_seqlens,
q_max_s, q_max_s,
kv_max_s, kv_max_s,
dropout_p = 0., dropout_p=0.,
softmax_scale = 1., # q has been scaled already softmax_scale=1., # q has been scaled already
) )
# [*, B, N, H, C] # [*, B, N, H, C]
......
...@@ -20,7 +20,7 @@ from typing import Optional, List ...@@ -20,7 +20,7 @@ from typing import Optional, List
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm, Attention from openfold.model.primitives import LayerNorm, Attention
from openfold.model.dropout import ( from openfold.model.dropout import (
DropoutRowwise, DropoutRowwise,
DropoutColumnwise, DropoutColumnwise,
...@@ -46,7 +46,6 @@ from openfold.utils.feats import ( ...@@ -46,7 +46,6 @@ from openfold.utils.feats import (
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
add, add,
permute_final_dims, permute_final_dims,
flatten_final_dims,
tensor_tree_map, tensor_tree_map,
) )
...@@ -200,7 +199,8 @@ class TemplatePairStackBlock(nn.Module): ...@@ -200,7 +199,8 @@ class TemplatePairStackBlock(nn.Module):
def forward(self, def forward(self,
z: torch.Tensor, z: torch.Tensor,
mask: torch.Tensor, mask: torch.Tensor,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
...@@ -226,6 +226,7 @@ class TemplatePairStackBlock(nn.Module): ...@@ -226,6 +226,7 @@ class TemplatePairStackBlock(nn.Module):
single, single,
chunk_size=_attn_chunk_size, chunk_size=_attn_chunk_size,
mask=single_mask, mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
) )
...@@ -239,6 +240,7 @@ class TemplatePairStackBlock(nn.Module): ...@@ -239,6 +240,7 @@ class TemplatePairStackBlock(nn.Module):
single, single,
chunk_size=_attn_chunk_size, chunk_size=_attn_chunk_size,
mask=single_mask, mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
) )
...@@ -355,6 +357,7 @@ class TemplatePairStack(nn.Module): ...@@ -355,6 +357,7 @@ class TemplatePairStack(nn.Module):
t: torch.tensor, t: torch.tensor,
mask: torch.tensor, mask: torch.tensor,
chunk_size: int, chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
_mask_trans: bool = True, _mask_trans: bool = True,
...@@ -378,6 +381,7 @@ class TemplatePairStack(nn.Module): ...@@ -378,6 +381,7 @@ class TemplatePairStack(nn.Module):
b, b,
mask=mask, mask=mask,
chunk_size=chunk_size, chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
...@@ -468,7 +472,9 @@ def embed_templates_offload( ...@@ -468,7 +472,9 @@ def embed_templates_offload(
t.unsqueeze(templ_dim), t.unsqueeze(templ_dim),
pair_mask.unsqueeze(-3).to(dtype=z.dtype), pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size, chunk_size=model.globals.chunk_size,
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
use_lma=model.globals.use_lma, use_lma=model.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=model.config._mask_trans, _mask_trans=model.config._mask_trans,
) )
...@@ -585,7 +591,9 @@ def embed_templates_average( ...@@ -585,7 +591,9 @@ def embed_templates_average(
t, t,
pair_mask.unsqueeze(-3).to(dtype=z.dtype), pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size, chunk_size=model.globals.chunk_size,
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
use_lma=model.globals.use_lma, use_lma=model.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=model.config._mask_trans, _mask_trans=model.config._mask_trans,
) )
......
...@@ -63,6 +63,7 @@ class TriangleAttention(nn.Module): ...@@ -63,6 +63,7 @@ class TriangleAttention(nn.Module):
biases: List[torch.Tensor], biases: List[torch.Tensor],
chunk_size: int, chunk_size: int,
use_memory_efficient_kernel: bool = False, use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -77,6 +78,7 @@ class TriangleAttention(nn.Module): ...@@ -77,6 +78,7 @@ class TriangleAttention(nn.Module):
partial( partial(
self.mha, self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel, use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma use_lma=use_lma
), ),
mha_inputs, mha_inputs,
...@@ -90,6 +92,7 @@ class TriangleAttention(nn.Module): ...@@ -90,6 +92,7 @@ class TriangleAttention(nn.Module):
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False, use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False, use_lma: bool = False,
inplace_safe: bool = False, inplace_safe: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -130,6 +133,7 @@ class TriangleAttention(nn.Module): ...@@ -130,6 +133,7 @@ class TriangleAttention(nn.Module):
biases, biases,
chunk_size, chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel, use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, use_lma=use_lma,
inplace_safe=inplace_safe, inplace_safe=inplace_safe,
) )
...@@ -139,6 +143,7 @@ class TriangleAttention(nn.Module): ...@@ -139,6 +143,7 @@ class TriangleAttention(nn.Module):
kv_x=x, kv_x=x,
biases=biases, biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel, use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma use_lma=use_lma
) )
......
...@@ -181,6 +181,7 @@ def trace_model_(model, sample_input): ...@@ -181,6 +181,7 @@ def trace_model_(model, sample_input):
("mask", msa_mask), ("mask", msa_mask),
("chunk_size", torch.tensor(evoformer_attn_chunk_size)), ("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)), ("use_memory_efficient_kernel", torch.tensor(False)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_lma", torch.tensor(model.globals.use_lma)), ("use_lma", torch.tensor(model.globals.use_lma)),
] ]
verify_arg_order( verify_arg_order(
...@@ -201,6 +202,7 @@ def trace_model_(model, sample_input): ...@@ -201,6 +202,7 @@ def trace_model_(model, sample_input):
("m", m), ("m", m),
("mask", msa_mask), ("mask", msa_mask),
("chunk_size", torch.tensor(evoformer_chunk_size)), ("chunk_size", torch.tensor(evoformer_chunk_size)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_lma", torch.tensor(model.globals.use_lma)), ("use_lma", torch.tensor(model.globals.use_lma)),
("use_flash", torch.tensor(model.globals.use_flash)), ("use_flash", torch.tensor(model.globals.use_flash)),
] ]
...@@ -283,6 +285,7 @@ def trace_model_(model, sample_input): ...@@ -283,6 +285,7 @@ def trace_model_(model, sample_input):
("mask", pair_mask.float()), ("mask", pair_mask.float()),
("chunk_size", torch.tensor(evoformer_attn_chunk_size)), ("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)), ("use_memory_efficient_kernel", torch.tensor(False)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_lma", torch.tensor(model.globals.use_lma)), ("use_lma", torch.tensor(model.globals.use_lma)),
("inplace_safe", torch.tensor(True)), ("inplace_safe", torch.tensor(True)),
] ]
...@@ -305,6 +308,7 @@ def trace_model_(model, sample_input): ...@@ -305,6 +308,7 @@ def trace_model_(model, sample_input):
("mask", pair_mask.transpose(-1, -2).float()), ("mask", pair_mask.transpose(-1, -2).float()),
("chunk_size", torch.tensor(evoformer_attn_chunk_size)), ("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)), ("use_memory_efficient_kernel", torch.tensor(False)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_lma", torch.tensor(model.globals.use_lma)), ("use_lma", torch.tensor(model.globals.use_lma)),
("inplace_safe", torch.tensor(True)), ("inplace_safe", torch.tensor(True)),
] ]
......
...@@ -13,6 +13,15 @@ gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats. ...@@ -13,6 +13,15 @@ gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.
python setup.py install python setup.py install
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH echo "Download CUTLASS, required for Deepspeed Evoformer attention kernel"
git clone https://github.com/NVIDIA/cutlass --depth 1
conda env config vars set CUTLASS_PATH=$PWD/cutlass
# This setting is used to fix a worker assignment issue during data loading # This setting is used to fix a worker assignment issue during data loading
conda env config vars set KMP_AFFINITY=none conda env config vars set KMP_AFFINITY=none
# Reactivate env so that the above environment variables take effect
conda activate $CONDA_PREFIX
export LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
...@@ -10,7 +10,6 @@ import numpy as np ...@@ -10,7 +10,6 @@ import numpy as np
from openfold.config import model_config from openfold.config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_ from openfold.utils.import_weights import import_jax_weights_
from tests.config import consts
# Give JAX some GPU memory discipline # Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also # (by default it hogs 90% of GPU memory. This disables that behavior and also
...@@ -19,6 +18,18 @@ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" ...@@ -19,6 +18,18 @@ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_PLATFORM_NAME"] = "gpu" os.environ["JAX_PLATFORM_NAME"] = "gpu"
def skip_unless_ds4s_installed():
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
"deepspeed.ops.deepspeed4science") is not None
return unittest.skipUnless(ds4s_is_installed, "Requires DeepSpeed with version ≥ 0.10.4")
def skip_unless_flash_attn_installed():
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
return unittest.skipUnless(fa_is_installed, "Requires Flash Attention")
def alphafold_is_installed(): def alphafold_is_installed():
return importlib.util.find_spec("alphafold") is not None return importlib.util.find_spec("alphafold") is not None
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch
import numpy as np import numpy as np
from scipy.spatial.transform import Rotation from scipy.spatial.transform import Rotation
...@@ -95,3 +96,17 @@ def random_affines_4x4(dim): ...@@ -95,3 +96,17 @@ def random_affines_4x4(dim):
affines[:, 3, 3] = 1 affines[:, 3, 3] = 1
return affines.reshape(*dim, 4, 4) return affines.reshape(*dim, 4, 4)
def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9,
dtype=torch.float32, requires_grad=False):
q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype, requires_grad=requires_grad).cuda()
kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype, requires_grad=requires_grad).cuda()
mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype, requires_grad=False).cuda()
z_bias = torch.rand(batch_size, 1, no_heads, n, n, dtype=dtype, requires_grad=requires_grad).cuda()
mask_bias = inf * (mask - 1)
biases = [mask_bias, z_bias]
return q, kv, mask, biases
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests to compare components of OpenFold run with the DeepSpeed memory-efficient
attention kernel, DS4Sci_EvoformerAttention vs. a stock PyTorch attention implementation.
"""
import unittest
import numpy as np
import pickle
import torch
from torch.nn import functional as F
from openfold.data import data_transforms
from openfold.model.primitives import (
lecun_normal_init_,
Attention
)
from openfold.utils.tensor_utils import tensor_tree_map
from tests.config import consts
import tests.compare_utils as compare_utils
from tests.data_utils import random_template_feats, random_attention_inputs
@compare_utils.skip_unless_ds4s_installed()
class TestDeepSpeedKernel(unittest.TestCase):
def compare_attention_types(self, use_flash=False):
"""Compare attention with and without using DeepSpeed Evoformer kernel."""
batch_size = consts.batch_size
n_seq = 18
n_res = 20
c_hidden = 32
no_heads = 4
eps = 2e-2
q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
n_seq=n_seq,
n=n_res,
no_heads=no_heads,
c_hidden=c_hidden)
a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
# Change output params init for testing since they are initialized with 'final' init (zeros)
# Otherwise both will just return zero.
with torch.no_grad():
lecun_normal_init_(a.linear_g.weight)
lecun_normal_init_(a.linear_o.weight)
if use_flash:
biases = [biases[0]]
flash_mask = mask.reshape(batch_size * n_seq, n_res)
real_out = a(q, kv, use_flash=True, flash_mask=flash_mask).cpu()
else:
real_out = a(q, kv, biases=biases).cpu()
ds_out = a(q, kv, biases=biases, use_deepspeed_evo_attention=True).cpu()
err = torch.max(torch.abs(ds_out - real_out))
self.assertTrue(err < eps, f'Error: {err}')
def test_ds_kernel_vs_attention_forward(self):
"""Compare regular attention vs. DeepSpeed Evoformer kernel."""
self.compare_attention_types(use_flash=False)
@compare_utils.skip_unless_flash_attn_installed()
def test_ds_kernel_vs_flash_attn_forward(self):
"""Compare Flash Attention vs. DeepSpeed Evoformer kernel."""
self.compare_attention_types(use_flash=True)
def test_ds_kernel_vs_attention_backward(self):
"""Compare backward pass for regular attention vs. DeepSpeed Evoformer kernel."""
batch_size = consts.batch_size
n_seq = 18
n_res = 20
c_hidden = 32
no_heads = 4
eps = consts.eps
q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
n_seq=n_seq,
n=n_res,
no_heads=no_heads,
c_hidden=c_hidden,
requires_grad=True)
attn = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
with torch.no_grad():
lecun_normal_init_(attn.linear_g.weight)
lecun_normal_init_(attn.linear_o.weight)
def clone(t):
# Create new params, clone values
t = t.clone()
if t.requires_grad:
t.retain_grad()
return t
def init_attn():
# Create new attention object with same initial weights
a_clone = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
a_clone.load_state_dict(attn.state_dict())
return a_clone
# Clone param values and run attention with DS kernel
q_repro = clone(q)
kv_repro = clone(kv)
biases_repro = [clone(b) for b in biases]
a_repro = init_attn()
out_repro = a_repro(q_repro, kv_repro, biases=biases_repro, use_deepspeed_evo_attention=True)
loss_repro = torch.mean(out_repro)
loss_repro.backward()
q_gt = clone(q)
kv_gt = clone(kv)
biases_gt = [clone(b) for b in biases]
# Clone param values and run attention without DS kernel
a_gt = init_attn()
out_gt = a_gt(q_gt, kv_gt, biases=biases_gt)
loss_gt = torch.mean(out_gt)
loss_gt.backward()
# Compare the grads of attention inputs
pairs = zip([q_repro, kv_repro, biases_repro[1]],
[q_gt, kv_gt, biases_gt[1]])
for i, item in enumerate(pairs):
t_repro, t_gt = item
err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
self.assertTrue(err < eps, f'Error item #{i}: {err}')
# Compare the grads of model weights
a_repro_params = dict(a_repro.named_parameters())
a_gt_params = dict(a_gt.named_parameters())
for name in a_gt_params.keys():
t_repro = a_repro_params[name]
t_gt = a_gt_params[name]
err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
self.assertTrue(err < eps, f'Error item {name}: {err}')
def compare_evoformer(self, dtype, eps):
"""
Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
Set dtype to confirm the kernel can be used during both training (BF16) and inference (FP32),
since the kernel itself can run with either BF16 or FP16 precision.
"""
n_res = 20
n_seq = 18
c_m_shape = (consts.c_m,)
c_z_shape = (consts.c_z,)
activations = {
"msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
"pair": torch.rand(n_res, n_res, consts.c_z, device='cuda', dtype=dtype)
}
masks = {
"msa": torch.randint(0, 2, (n_seq, n_res), device='cuda', dtype=dtype),
"pair": torch.randint(0, 2, (n_res, n_res), device='cuda', dtype=dtype),
}
with torch.cuda.amp.autocast(dtype=dtype):
model = compare_utils.get_global_pretrained_openfold()
out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
activations["msa"],
activations["pair"],
masks["msa"],
masks["pair"],
use_deepspeed_evo_attention=False,
chunk_size=4,
_mask_trans=False,
inplace_safe=False,
)
# In practice, layer norms applied later in the network make any
# kernel rounding errors negligible
out_repro_msa = F.layer_norm(out_repro_msa, c_m_shape).cpu()
out_repro_pair = F.layer_norm(out_repro_pair, c_z_shape).cpu()
out_repro_msa_ds, out_repro_pair_ds = model.evoformer.blocks[0](
activations["msa"],
activations["pair"],
masks["msa"],
masks["pair"],
use_deepspeed_evo_attention=True,
chunk_size=4,
_mask_trans=False,
inplace_safe=False,
)
out_repro_msa_ds = F.layer_norm(out_repro_msa_ds, c_m_shape).cpu()
out_repro_pair_ds = F.layer_norm(out_repro_pair_ds, c_z_shape).cpu()
err = torch.mean(torch.abs(out_repro_msa - out_repro_msa_ds))
self.assertTrue(err < eps, f'MSA Error: {err}')
err = torch.mean(torch.abs(out_repro_pair - out_repro_pair_ds))
self.assertTrue(err < eps, f'Pair Error {err}')
def test_compare_evoformer_bf16(self):
"""Run evoformer comparison test with BF16 precision."""
self.compare_evoformer(dtype=torch.bfloat16, eps=4e-2)
def test_compare_evoformer_fp32(self):
"""Run evoformer comparison test with FP32 precision."""
self.compare_evoformer(dtype=torch.float32, eps=2e-2)
def test_compare_template_stack(self):
"""
Compare Template Stack output with and without using DeepSpeed Evoformer attention kernel.
Kernel can be used for Triangle Attention in the Template Pair Stack.
"""
n_templ = consts.n_templ
n_res = 20
eps = 2e-2
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]
with torch.no_grad():
model = compare_utils.get_global_pretrained_openfold()
model.globals.use_deepspeed_evo_attention = False
out_repro = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
inplace_safe=False
)
out_repro = out_repro["template_pair_embedding"].cpu()
model.globals.use_deepspeed_evo_attention = True
out_repro_ds = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
inplace_safe=False
)
out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
err = torch.max(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error {err}')
def test_compare_model(self):
"""
Run full model with and without using DeepSpeed Evoformer attention kernel
and compare output coordinates.
"""
eps = 0.5
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
batch = pickle.load(fp)
# atom37_to_atom14 doesn't like batches
batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
batch["no_recycling_iters"] = np.array([3., 3., 3., 3., ])
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
batch["aatype"] = batch["aatype"].long()
batch["template_aatype"] = batch["template_aatype"].long()
batch["extra_msa"] = batch["extra_msa"].long()
batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14"
].long()
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch)
)
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch)
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
model = compare_utils.get_global_pretrained_openfold()
model.globals.use_deepspeed_evo_attention = False
out_repro = model(batch)
# Enable kernel
model.globals.use_deepspeed_evo_attention = True
out_repro_ds = model(batch)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
out_repro_ds = tensor_tree_map(lambda t: t.cpu(), out_repro_ds)
out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
err = torch.max(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error: {err}')
if __name__ == "__main__":
unittest.main()
...@@ -13,40 +13,40 @@ ...@@ -13,40 +13,40 @@
# limitations under the License. # limitations under the License.
import torch import torch
import numpy as np
import unittest import unittest
from openfold.model.primitives import ( from openfold.model.primitives import (
lecun_normal_init_,
Attention, Attention,
) )
from tests.config import consts from tests.config import consts
from tests.data_utils import random_attention_inputs
class TestLMA(unittest.TestCase): class TestLMA(unittest.TestCase):
def test_lma_vs_attention(self): def test_lma_vs_attention(self):
batch_size = consts.batch_size c_hidden = 32
c_hidden = 32
n = 2**12
no_heads = 4 no_heads = 4
q = torch.rand(batch_size, n, c_hidden).cuda() q, kv, _, biases = random_attention_inputs(batch_size=consts.batch_size,
kv = torch.rand(batch_size, n, c_hidden).cuda() n_seq=consts.n_seq,
n=2**12,
no_heads=no_heads,
c_hidden=c_hidden)
bias = [torch.rand(no_heads, 1, n)]
bias = [b.cuda() for b in bias]
gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
o_fill = torch.rand(c_hidden, c_hidden * no_heads)
a = Attention( a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda() ).cuda()
with torch.no_grad(): with torch.no_grad():
l = a(q, kv, biases=bias, use_lma=True) lecun_normal_init_(a.linear_g.weight)
real = a(q, kv, biases=bias) lecun_normal_init_(a.linear_o.weight)
self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps) l = a(q, kv, biases=biases, use_lma=True).cpu()
real = a(q, kv, biases=biases).cpu()
err = torch.max(torch.abs(l - real))
self.assertTrue(err < consts.eps, f'Error: {err}')
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment