"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "c56d0dea2c6f8a917e1fb4f98dddbd119558cbe7"
Commit 2bf18520 authored by Christina Floristean

Clean up DS kernel integration and test, add cutlass to installation procedure

parent a6703606
@@ -28,8 +28,6 @@ dependencies:
   - wandb==0.12.21
   - modelcif==0.7
   - git+https://github.com/NVIDIA/dllogger.git
-  - git+https://github.com/NVIDIA/cutlass.git
   - git+https://github.com/microsoft/DeepSpeed.git
   # TODO: Replace above when version becomes available
-  # - deepspeed==0.10.3
+  # - deepspeed==0.10.4
@@ -367,12 +367,15 @@ config = mlc.ConfigDict(
         "globals": {
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
+            # Use DeepSpeed memory-efficient attention kernel. Mutually
+            # exclusive with use_lma and use_flash.
             "use_deepspeed_evo_attention": False,
             # Use Staats & Rabe's low-memory attention algorithm. Mutually
-            # exclusive with use_flash.
+            # exclusive with use_deepspeed_evo_attention and use_flash.
             "use_lma": False,
             # Use FlashAttention in selected modules. Mutually exclusive with
-            # use_lma. Doesn't work that well on long sequences (>1000 residues).
+            # use_deepspeed_evo_attention and use_lma. Doesn't work that well
+            # on long sequences (>1000 residues).
             "use_flash": False,
             "offload_inference": False,
             "c_z": c_z,
......
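Note on the three flags above: they are mutually exclusive, so at most one of use_deepspeed_evo_attention, use_lma, and use_flash should be enabled at a time. A minimal sketch of enabling the DeepSpeed kernel for inference (the preset name and arguments passed to model_config are illustrative, not part of this commit):

    from openfold.config import model_config

    # Build an inference-time config and enable exactly one attention optimization.
    config = model_config("model_1", train=False)
    config.globals.use_deepspeed_evo_attention = True   # DeepSpeed DS4Sci kernel
    config.globals.use_lma = False                       # low-memory attention off
    config.globals.use_flash = False                     # FlashAttention off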
@@ -801,10 +801,15 @@ class EvoformerStack(nn.Module):
             chunk_size:
                 Inference-time subbatch size. Acts as a minimum if
                 self.tune_chunk_size is True
-            use_lma: Whether to use low-memory attention during inference
+            use_deepspeed_evo_attention:
+                Whether to use the DeepSpeed memory-efficient attention kernel.
+                Mutually exclusive with use_lma and use_flash.
+            use_lma:
+                Whether to use low-memory attention during inference.
+                Mutually exclusive with use_flash and use_deepspeed_evo_attention.
             use_flash:
                 Whether to use FlashAttention where possible. Mutually
-                exclusive with use_lma.
+                exclusive with use_lma and use_deepspeed_evo_attention.
         Returns:
             m:
                 [*, N_seq, N_res, C_m] MSA embedding
@@ -1000,6 +1005,7 @@ class ExtraMSAStack(nn.Module):
             z:
                 [*, N_res, N_res, C_z] pair embedding
             chunk_size: Inference-time subbatch size for Evoformer modules
+            use_deepspeed_evo_attention: Whether to use the DeepSpeed memory-efficient kernel
             use_lma: Whether to use low-memory attention during inference
             msa_mask:
                 Optional [*, N_extra, N_res] MSA mask
......
@@ -20,7 +20,9 @@ import numpy as np
 deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
 if deepspeed_is_installed:
     import deepspeed
-    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
+    if importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None:
+        from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
 if fa_is_installed:
@@ -375,7 +377,8 @@ class Attention(nn.Module):
     def _prep_qkv(self,
         q_x: torch.Tensor,
-        kv_x: torch.Tensor
+        kv_x: torch.Tensor,
+        transpose_qkv_dims: bool = True
     ) -> Tuple[
         torch.Tensor, torch.Tensor, torch.Tensor
     ]:
@@ -389,10 +392,11 @@ class Attention(nn.Module):
         k = k.view(k.shape[:-1] + (self.no_heads, -1))
         v = v.view(v.shape[:-1] + (self.no_heads, -1))

-        # [*, H, Q/K, C_hidden]
-        q = q.transpose(-2, -3)
-        k = k.transpose(-2, -3)
-        v = v.transpose(-2, -3)
+        if transpose_qkv_dims:
+            # [*, H, Q/K, C_hidden]
+            q = q.transpose(-2, -3)
+            k = k.transpose(-2, -3)
+            v = v.transpose(-2, -3)

         q /= math.sqrt(self.c_hidden)
@@ -479,10 +483,10 @@ class Attention(nn.Module):
         if biases is None:
             biases = []

-        # [*, H, Q/K, C_hidden]
-        q, k, v = self._prep_qkv(q_x, kv_x)
-        # [*, Q, H, C_hidden]
+        # DeepSpeed attention kernel expects Q/K/V of shape [*, Q/K, H, C_hidden]
+        # All other attention modules expect Q/K/V of shape [*, H, Q/K, C_hidden]
+        q, k, v = self._prep_qkv(q_x, kv_x, transpose_qkv_dims=not use_deepspeed_evo_attention)

         if is_fp16_enabled():
             use_memory_efficient_kernel = False
@@ -495,17 +499,32 @@ class Attention(nn.Module):
             o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
             o = o.transpose(-2, -3)
         elif use_deepspeed_evo_attention:
-            q = q.transpose(-2, -3)
-            k = k.transpose(-2, -3)
-            v = v.transpose(-2, -3)
-
-            add_batch_dim = len(q.shape) < 5
-            if add_batch_dim:
-                q = q.unsqueeze(0)
-                k = k.unsqueeze(0)
-                v = v.unsqueeze(0)
-                biases = [b.unsqueeze(0) for b in biases]
+            if len(biases) > 2:
+                raise ValueError(
+                    "If use_deepspeed_evo_attention is True, you may only "
+                    "provide up to two bias terms"
+                )
+
+            orig_shape = q.shape
+            no_batch_dims = len(orig_shape[:-3])
+            if no_batch_dims > 2:
+                raise ValueError(
+                    f"Q is of shape {list(orig_shape)} but must be "
+                    "of shape [B, N, Q/K, H, C_hidden] if "
+                    "use_deepspeed_evo_attention is True."
+                )
+
+            # Bypass asserts for bias shapes in DS4Sci_EvoformerAttention()
+            # by adding batch and N_seq dims if needed.
+            if no_batch_dims < 2:
+                addl_dims = (1,) * (2 - no_batch_dims)
+                q = q.view(*(addl_dims + q.shape))
+                k = k.view(*(addl_dims + k.shape))
+                v = v.view(*(addl_dims + v.shape))
+                biases = [b.view(*(addl_dims + b.shape)) for b in biases]
+
+            # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16.
+            # Cast to bf16 so the kernel can be used during inference.
             orig_dtype = q.dtype
             if orig_dtype not in [torch.bfloat16, torch.float16]:
                 o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
@@ -517,8 +536,7 @@ class Attention(nn.Module):
             else:
                 o = DS4Sci_EvoformerAttention(q, k, v, biases)

-            if add_batch_dim:
-                o = o.squeeze(0)
+            o = o.view(orig_shape)
         elif use_lma:
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
......
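For reference, the reshaping added above can be exercised on its own: the DeepSpeed kernel expects five-dimensional inputs of shape [B, N, Q/K, H, C_hidden], so tensors with fewer leading dimensions get singleton dims prepended, and the output is viewed back to the original shape afterwards. A self-contained sketch (tensor sizes are arbitrary example values and the kernel call itself is stubbed out):

    import torch

    # [N_seq, Q, H, C_hidden]: only one leading dim, so one more is needed.
    q = torch.rand(12, 256, 4, 32)

    orig_shape = q.shape
    no_batch_dims = len(orig_shape[:-3])        # dims before (Q, H, C_hidden) -> 1
    if no_batch_dims < 2:
        addl_dims = (1,) * (2 - no_batch_dims)  # singleton dims to prepend
        q = q.view(*(addl_dims + q.shape))      # -> [1, 12, 256, 4, 32]

    orig_dtype = q.dtype
    if orig_dtype not in (torch.bfloat16, torch.float16):
        q = q.to(torch.bfloat16)                # kernel only runs in bf16/fp16

    o = q                                       # stand-in for the kernel output
    o = o.view(orig_shape)                      # restore the original leading dims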
@@ -181,6 +181,7 @@ def trace_model_(model, sample_input):
         ("mask", msa_mask),
         ("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
         ("use_memory_efficient_kernel", torch.tensor(False)),
+        ("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
         ("use_lma", torch.tensor(model.globals.use_lma)),
     ]
     verify_arg_order(
@@ -201,6 +202,7 @@ def trace_model_(model, sample_input):
         ("m", m),
         ("mask", msa_mask),
         ("chunk_size", torch.tensor(evoformer_chunk_size)),
+        ("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
         ("use_lma", torch.tensor(model.globals.use_lma)),
         ("use_flash", torch.tensor(model.globals.use_flash)),
     ]
@@ -283,6 +285,7 @@ def trace_model_(model, sample_input):
         ("mask", pair_mask.float()),
         ("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
         ("use_memory_efficient_kernel", torch.tensor(False)),
+        ("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
         ("use_lma", torch.tensor(model.globals.use_lma)),
         ("inplace_safe", torch.tensor(True)),
     ]
@@ -305,6 +308,7 @@ def trace_model_(model, sample_input):
         ("mask", pair_mask.transpose(-1, -2).float()),
         ("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
         ("use_memory_efficient_kernel", torch.tensor(False)),
+        ("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
         ("use_lma", torch.tensor(model.globals.use_lma)),
         ("inplace_safe", torch.tensor(True)),
     ]
......
@@ -25,6 +25,11 @@ git checkout 5b838a8bef
 python3 setup.py install
 cd $CUR_DIR

+echo "Attempting to download CUTLASS, required for the DeepSpeed Evoformer attention kernel"
+git clone https://github.com/NVIDIA/cutlass.git
+conda env config vars set CUTLASS_PATH=$PWD/cutlass
+source scripts/activate_conda_env.sh
+
 # Install DeepMind's OpenMM patch
 OPENFOLD_DIR=$PWD
 pushd lib/conda/envs/$ENV_NAME/lib/python3.9/site-packages/ \
......
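Because the DeepSpeed Evoformer kernel is built against CUTLASS, the script above clones CUTLASS and exports CUTLASS_PATH into the conda environment. A small, illustrative sanity check (not part of the repository) one might run before enabling use_deepspeed_evo_attention:

    import importlib.util
    import os

    # CUTLASS_PATH must point at the CUTLASS checkout created by the install script.
    assert os.environ.get("CUTLASS_PATH"), "CUTLASS_PATH is not set in this environment"

    # Mirrors the guarded import in primitives.py: the kernel is only available when
    # the installed DeepSpeed build ships the deepspeed4science ops.
    assert importlib.util.find_spec("deepspeed") is not None, "deepspeed is not installed"
    assert importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None, \
        "this DeepSpeed build does not provide DS4Sci_EvoformerAttention"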
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Unit tests comparing components of OpenFold run with the DeepSpeed memory-efficient
+attention kernel, DS4Sci_EvoformerAttention, against a stock PyTorch attention implementation.
+"""

 import torch
 import unittest
 import numpy as np
@@ -22,17 +27,26 @@ from openfold.model.primitives import (
 )
 from tests.config import consts
 import tests.compare_utils as compare_utils
+from tests.data_utils import (
+    random_template_feats,
+    random_extra_msa_feats,
+)

+from openfold.config import model_config
 from openfold.data import data_transforms
+from openfold.model.model import AlphaFold
 from openfold.utils.tensor_utils import tensor_tree_map


 class TestDeepSpeedKernel(unittest.TestCase):
     def test_ds_kernel_vs_attention(self):
+        """Compare attention with and without the DeepSpeed Evoformer kernel."""
         batch_size = consts.batch_size
         c_hidden = 32
         n = 2 ** 12
         n_seq = 12
         no_heads = 4
+        eps = 2e-2

         q = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
         kv = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
@@ -48,11 +62,17 @@ class TestDeepSpeedKernel(unittest.TestCase):
         l = a(q, kv, biases=bias, use_deepspeed_evo_attention=True)
         real = a(q, kv, biases=bias)

-        self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
+        self.assertTrue(torch.max(torch.abs(l - real)) < eps)

     def compare_evoformer(self, dtype):
+        """
+        Compare Evoformer output with and without the DeepSpeed Evoformer attention kernel.
+        The dtype argument confirms the kernel can be used during both training (BF16) and
+        inference (FP32), since the kernel itself runs in either BF16 or FP16 precision.
+        """
         n_res = consts.n_res
         n_seq = consts.n_seq
+        eps = 2e-2

         activations = {
             "msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
@@ -93,16 +113,23 @@ class TestDeepSpeedKernel(unittest.TestCase):
         out_repro_msa_ds = out_repro_msa_ds.cpu()
         out_repro_pair_ds = out_repro_pair_ds.cpu()

-        self.assertTrue(torch.allclose(torch.abs(out_repro_msa), torch.abs(out_repro_msa_ds), atol=consts.eps))
-        self.assertTrue(torch.allclose(torch.abs(out_repro_pair), torch.abs(out_repro_pair_ds), atol=consts.eps))
+        self.assertTrue(torch.allclose(torch.abs(out_repro_msa), torch.abs(out_repro_msa_ds), atol=eps))
+        self.assertTrue(torch.allclose(torch.abs(out_repro_pair), torch.abs(out_repro_pair_ds), atol=eps))

     def test_compare_evoformer_bf16(self):
+        """Run the Evoformer comparison test with BF16 precision."""
         self.compare_evoformer(torch.bfloat16)

     def test_compare_evoformer_fp32(self):
+        """Run the Evoformer comparison test with FP32 precision."""
         self.compare_evoformer(torch.float32)

-    def test_dry_run(self):
+    def test_compare_model(self):
+        """
+        Run the full model with and without the DeepSpeed Evoformer attention kernel
+        and compare the output coordinates.
+        """
+        eps = 2e-2
         with open("tests/test_data/sample_feats.pickle", "rb") as fp:
             batch = pickle.load(fp)
@@ -130,9 +157,20 @@ class TestDeepSpeedKernel(unittest.TestCase):
         with torch.no_grad():
             model = compare_utils.get_global_pretrained_openfold()
+            model.globals.use_deepspeed_evo_attention = False
             out_repro = model(batch)

+            # Enable kernel
+            model.globals.use_deepspeed_evo_attention = True
+            out_repro_ds = model(batch)
+
+            out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
+            out_repro_ds = tensor_tree_map(lambda t: t.cpu(), out_repro_ds)
+
+            out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
+            out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
+
+            self.assertTrue(torch.max(torch.abs(out_repro - out_repro_ds)) < eps)


 if __name__ == "__main__":
     unittest.main()