Commit a3de9cb9 authored by Christina Floristean

Added kernel to template pair stack and updated tests

parent a0985761
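For orientation (editor's aside, not part of the commit): the hunks below thread a new use_deepspeed_evo_attention flag from the model globals down through the template pair stack, mirroring the existing use_lma plumbing. A minimal usage sketch, assuming the model_config/AlphaFold entry points that the test hunks below already import; the preset name is illustrative:

    # Sketch only: enable the DeepSpeed Evoformer attention kernel at
    # inference time (requires DeepSpeed with the DS4Science ops; see the
    # version check added to the tests below).
    from openfold.config import model_config
    from openfold.model.model import AlphaFold

    config = model_config("model_1")  # "model_1" is an assumed preset name
    model = AlphaFold(config).cuda().eval()
    model.globals.use_deepspeed_evo_attention = True  # the switch the tests flip
    # outputs = model(batch)  # batch comes from the usual OpenFold data pipeline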
@@ -169,6 +169,7 @@ class AlphaFold(nn.Module):
             t_pair,
             pair_mask.unsqueeze(-3).to(dtype=z.dtype),
             chunk_size=self.globals.chunk_size,
+            use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
             use_lma=self.globals.use_lma,
             inplace_safe=inplace_safe,
             _mask_trans=self.config._mask_trans,
@@ -20,7 +20,7 @@ from typing import Optional, List
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear, LayerNorm, Attention
+from openfold.model.primitives import LayerNorm, Attention
 from openfold.model.dropout import (
     DropoutRowwise,
     DropoutColumnwise,
@@ -46,7 +46,6 @@ from openfold.utils.feats import (
 from openfold.utils.tensor_utils import (
     add,
     permute_final_dims,
-    flatten_final_dims,
     tensor_tree_map,
 )
@@ -201,6 +200,7 @@ class TemplatePairStackBlock(nn.Module):
         z: torch.Tensor,
         mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,
@@ -226,6 +226,7 @@ class TemplatePairStackBlock(nn.Module):
                 single,
                 chunk_size=_attn_chunk_size,
                 mask=single_mask,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma,
                 inplace_safe=inplace_safe,
             )
@@ -239,6 +240,7 @@ class TemplatePairStackBlock(nn.Module):
                 single,
                 chunk_size=_attn_chunk_size,
                 mask=single_mask,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma,
                 inplace_safe=inplace_safe,
             )
@@ -355,6 +357,7 @@ class TemplatePairStack(nn.Module):
         t: torch.tensor,
         mask: torch.tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,
@@ -378,6 +381,7 @@ class TemplatePairStack(nn.Module):
                 b,
                 mask=mask,
                 chunk_size=chunk_size,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma,
                 inplace_safe=inplace_safe,
                 _mask_trans=_mask_trans,
@@ -468,6 +472,7 @@ def embed_templates_offload(
         t.unsqueeze(templ_dim),
         pair_mask.unsqueeze(-3).to(dtype=z.dtype),
         chunk_size=model.globals.chunk_size,
+        use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
         use_lma=model.globals.use_lma,
         _mask_trans=model.config._mask_trans,
     )
@@ -585,6 +590,7 @@ def embed_templates_average(
         t,
         pair_mask.unsqueeze(-3).to(dtype=z.dtype),
         chunk_size=model.globals.chunk_size,
+        use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
        use_lma=model.globals.use_lma,
         _mask_trans=model.config._mask_trans,
     )
@@ -10,7 +10,6 @@ import numpy as np
 from openfold.config import model_config
 from openfold.model.model import AlphaFold
 from openfold.utils.import_weights import import_jax_weights_
-from tests.config import consts
 
 # Give JAX some GPU memory discipline
 # (by default it hogs 90% of GPU memory. This disables that behavior and also
@@ -19,6 +18,18 @@ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
 os.environ["JAX_PLATFORM_NAME"] = "gpu"
 
 
+def skip_unless_ds4s_installed():
+    deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
+    ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
+        "deepspeed.ops.deepspeed4science") is not None
+    return unittest.skipUnless(ds4s_is_installed, "Requires DeepSpeed with version ≥ 0.10.4")
+
+
+def skip_unless_flash_attn_installed():
+    fa_is_installed = importlib.util.find_spec("flash_attn") is not None
+    return unittest.skipUnless(fa_is_installed, "Requires Flash Attention")
+
+
 def alphafold_is_installed():
     return importlib.util.find_spec("alphafold") is not None
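Aside (not part of the diff): the decorators above only check that the relevant modules are importable. A quick availability probe using the same importlib calls, runnable in a Python shell on the target machine:

    # Mirrors skip_unless_ds4s_installed(); prints False if the DS4S ops are
    # absent. Says nothing about kernel correctness, only importability.
    import importlib.util

    print(importlib.util.find_spec("deepspeed") is not None)
    print(importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None)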
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import torch
 import numpy as np
 from scipy.spatial.transform import Rotation
@@ -95,3 +96,14 @@ def random_affines_4x4(dim):
     affines[:, 3, 3] = 1
     return affines.reshape(*dim, 4, 4)
+
+
+def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9, dtype=torch.float32):
+    q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype).cuda()
+    kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype).cuda()
+    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype).cuda()
+
+    biases = [inf * (mask - 1), torch.rand(batch_size, 1, no_heads, n, n)]
+    biases = [b.to(dtype=dtype).cuda() for b in biases]
+
+    return q, kv, mask, biases
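Note on the helper above (editor's aside): inf * (mask - 1) turns the 0/1 keep-mask into an additive attention bias, 0 where mask == 1 and -inf (approximated by -1e9) where mask == 0, so masked positions receive roughly zero weight after the softmax. A tiny self-contained illustration with made-up values:

    import torch

    inf = 1e9
    mask = torch.tensor([1., 0., 1.])       # 1 = attend, 0 = masked
    bias = inf * (mask - 1)                 # -> [0., -1e9, 0.]
    logits = torch.tensor([0.5, 2.0, 1.0])  # arbitrary pre-softmax scores
    print(torch.softmax(logits + bias, dim=-1))  # middle slot gets ~0 weight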
@@ -22,45 +22,63 @@ import unittest
 import numpy as np
 import pickle
 
-from openfold.data import data_transforms
 from openfold.model.primitives import (
-    _attention,
-    _deepspeed_evo_attn
+    lecun_normal_init_,
+    Attention,
 )
-from openfold.utils.tensor_utils import tensor_tree_map
 from tests.config import consts
 import tests.compare_utils as compare_utils
+from openfold.data import data_transforms
+from openfold.utils.tensor_utils import tensor_tree_map
+from tests.data_utils import random_template_feats, random_attention_inputs
 
 
 @compare_utils.skip_unless_ds4s_installed()
 class TestDeepSpeedKernel(unittest.TestCase):
-    def test_ds_kernel_vs_attention(self):
+    def compare_attention_types(self, use_flash=False):
+        """Compare attention with and without using DeepSpeed Evoformer kernel."""
         batch_size = consts.batch_size
-        c_hidden = 32
-        n_seq = consts.n_seq
         n = 2 ** 12
+        n_seq = 12
+        c_hidden = 32
         no_heads = 4
         dtype = torch.bfloat16
+        eps = 2e-2
 
-        q = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
-        k = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
-        v = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
+        q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
+                                                      n_seq=n_seq,
+                                                      n=n,
+                                                      no_heads=no_heads,
+                                                      c_hidden=c_hidden)
 
-        bias = [torch.rand(batch_size, n_seq, 1, 1, n), torch.rand(batch_size, 1, no_heads, n, n)]
-        bias = [b.to(dtype=dtype).cuda() for b in bias]
+        a = Attention(
+            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
+        ).cuda()
 
         with torch.no_grad():
-            l = _deepspeed_evo_attn(q, k, v, biases=bias).cpu()
+            lecun_normal_init_(a.linear_g.weight)
+            lecun_normal_init_(a.linear_o.weight)
 
-            q = q.transpose(-2, -3)
-            k = k.transpose(-2, -3)
-            v = v.transpose(-2, -3)
-            real = _attention(q, k, v, biases=bias)
-            real = real.transpose(-2, -3).cpu()
+            if use_flash:
+                biases = [biases[0]]
+                flash_mask = mask.reshape(batch_size * n_seq, n)
+                real_out = a(q, kv, use_flash=True, flash_mask=flash_mask).cpu()
+            else:
+                real_out = a(q, kv, biases=biases).cpu()
 
-        err = torch.max(torch.abs(l - real))
-        self.assertTrue(err < consts.eps, f'Error: {err}')
+            ds_out = a(q, kv, biases=biases, use_deepspeed_evo_attention=True).cpu()
+
+            err = torch.max(torch.abs(ds_out - real_out))
+            self.assertTrue(err < eps, f'Error: {err}')
+
+    def test_ds_kernel_vs_attention(self):
+        """Compare regular attention vs. DeepSpeed Evoformer kernel."""
+        self.compare_attention_types(use_flash=False)
+
+    @compare_utils.skip_unless_flash_attn_installed()
+    def test_ds_kernel_vs_flash_attention(self):
+        """Compare Flash Attention vs. DeepSpeed Evoformer kernel."""
+        self.compare_attention_types(use_flash=True)
 
     def compare_evoformer(self, dtype):
         """
@@ -70,7 +88,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
         """
         n_res = 20
         n_seq = 18
-        eps = 2e-2
+        eps = 0.5
 
         activations = {
             "msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
@@ -111,8 +129,11 @@ class TestDeepSpeedKernel(unittest.TestCase):
         out_repro_msa_ds = out_repro_msa_ds.cpu()
         out_repro_pair_ds = out_repro_pair_ds.cpu()
 
-        self.assertTrue(torch.allclose(torch.abs(out_repro_msa), torch.abs(out_repro_msa_ds), atol=eps))
-        self.assertTrue(torch.allclose(torch.abs(out_repro_pair), torch.abs(out_repro_pair_ds), atol=eps))
+        err = torch.mean(torch.abs(out_repro_msa - out_repro_msa_ds))
+        self.assertTrue(err < eps, f'MSA Error: {err}')
+
+        err = torch.mean(torch.abs(out_repro_pair - out_repro_pair_ds))
+        self.assertTrue(err < eps, f'Pair Error: {err}')
 
     def test_compare_evoformer_bf16(self):
         """Run evoformer comparison test with BF16 precision."""
@@ -122,12 +143,54 @@
         """Run evoformer comparison test with FP32 precision."""
         self.compare_evoformer(torch.float32)
 
+    def test_compare_template_stack(self):
+        """
+        Compare template stack output with and without the DeepSpeed Evoformer
+        attention kernel. The kernel can be used for triangle attention in the
+        template pair stack.
+        """
+        n_templ = consts.n_templ
+        n_res = 20
+        eps = 2e-2
+
+        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
+        batch = random_template_feats(n_templ, n_res)
+        batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
+        pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
+
+        inds = np.random.randint(0, 21, (n_res,))
+        batch["target_feat"] = np.eye(22)[inds]
+
+        with torch.no_grad():
+            model = compare_utils.get_global_pretrained_openfold()
+            model.globals.use_deepspeed_evo_attention = False
+            out_repro = model.embed_templates(
+                {k: torch.as_tensor(v).cuda() for k, v in batch.items()},
+                torch.as_tensor(pair_act).cuda(),
+                torch.as_tensor(pair_mask).cuda(),
+                templ_dim=0,
+                inplace_safe=False,
+            )
+            out_repro = out_repro["template_pair_embedding"].cpu()
+
+            model.globals.use_deepspeed_evo_attention = True
+            out_repro_ds = model.embed_templates(
+                {k: torch.as_tensor(v).cuda() for k, v in batch.items()},
+                torch.as_tensor(pair_act).cuda(),
+                torch.as_tensor(pair_mask).cuda(),
+                templ_dim=0,
+                inplace_safe=False,
+            )
+            out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
+
+            err = torch.max(torch.abs(out_repro - out_repro_ds))
+            self.assertTrue(err < eps, f'Error: {err}')
+
     def test_compare_model(self):
         """
         Run the full model with and without the DeepSpeed Evoformer attention
         kernel and compare the output coordinates.
         """
-        eps = 2e-2
+        eps = 0.5
 
         with open("tests/test_data/sample_feats.pickle", "rb") as fp:
             batch = pickle.load(fp)
@@ -13,40 +13,37 @@
 # limitations under the License.
 
 import torch
-import numpy as np
 import unittest
 
 from openfold.model.primitives import (
-    _lma,
-    _attention,
-    DEFAULT_LMA_Q_CHUNK_SIZE,
-    DEFAULT_LMA_KV_CHUNK_SIZE
+    lecun_normal_init_,
+    Attention,
 )
 from tests.config import consts
+from tests.data_utils import random_attention_inputs
 
 
 class TestLMA(unittest.TestCase):
     def test_lma_vs_attention(self):
-        batch_size = consts.batch_size
         c_hidden = 32
-        n = 2**12
-        n_seq = 12
         no_heads = 4
 
-        q = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
-        k = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
-        v = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
+        q, kv, _, biases = random_attention_inputs(batch_size=consts.batch_size,
+                                                   n_seq=consts.n_seq,
+                                                   n=2**12,
+                                                   no_heads=no_heads,
+                                                   c_hidden=c_hidden)
 
-        bias = [torch.rand(batch_size, n_seq, 1, 1, n), torch.rand(batch_size, 1, no_heads, n, n)]
-        biases = [b.cuda() for b in bias]
+        a = Attention(
+            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
+        ).cuda()
 
         with torch.no_grad():
-            lma_biases = [
-                b.expand(b.shape[:-2] + (q.shape[-2],) + (k.shape[-2],))
-                for b in biases
-            ]
-            l = _lma(q, k, v, lma_biases, DEFAULT_LMA_Q_CHUNK_SIZE, DEFAULT_LMA_KV_CHUNK_SIZE).cpu()
-            real = _attention(q, k, v, biases).cpu()
+            lecun_normal_init_(a.linear_g.weight)
+            lecun_normal_init_(a.linear_o.weight)
+
+            l = a(q, kv, biases=biases, use_lma=True).cpu()
+            real = a(q, kv, biases=biases).cpu()
 
         err = torch.max(torch.abs(l - real))
         self.assertTrue(err < consts.eps, f'Error: {err}')
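Aside (not part of the diff): use_lma routes the same Attention forward pass through low-memory attention, which accumulates the softmax over query and key-value chunks instead of materializing the full attention matrix (OpenFold's real implementation is openfold.model.primitives._lma, with the DEFAULT_LMA_*_CHUNK_SIZE constants the old test imported). A rough sketch of the chunking idea only: single head, no bias, numerical-stability tricks omitted, chunk sizes chosen arbitrarily.

    import torch

    def lma_sketch(q, k, v, q_chunk=1024, kv_chunk=4096):
        # q, k, v: [n, c]. For each query block, accumulate the softmax
        # numerator/denominator over key-value blocks so the full [n, n]
        # score matrix never exists in memory.
        scale = q.shape[-1] ** -0.5
        out = []
        for i in range(0, q.shape[0], q_chunk):
            qc = q[i:i + q_chunk]
            num = torch.zeros(qc.shape[0], v.shape[-1])
            den = torch.zeros(qc.shape[0], 1)
            for j in range(0, k.shape[0], kv_chunk):
                w = torch.exp(qc @ k[j:j + kv_chunk].T * scale)
                num = num + w @ v[j:j + kv_chunk]
                den = den + w.sum(dim=-1, keepdim=True)
            out.append(num / den)
        return torch.cat(out, dim=0)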