"components/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "602352ce190bcb02013c62c2337e8b8678015699"
Commit a3de9cb9 authored by Christina Floristean

Added kernel to template pair stack and updated tests
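
For context, this commit threads a new use_deepspeed_evo_attention flag from the model globals down through the template pair stack, mirroring the existing use_lma switch. A minimal usage sketch, assuming the standard OpenFold config API (the flag name is the one added by this commit; "model_1" is just an example preset):

    import torch
    from openfold.config import model_config
    from openfold.model.model import AlphaFold

    config = model_config("model_1")
    # Toggle the DeepSpeed Evoformer attention kernel the same way the
    # existing config.globals.use_lma switch is toggled.
    config.globals.use_deepspeed_evo_attention = True
    model = AlphaFold(config).cuda().eval()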

parent a0985761
@@ -169,6 +169,7 @@ class AlphaFold(nn.Module):
             t_pair,
             pair_mask.unsqueeze(-3).to(dtype=z.dtype),
             chunk_size=self.globals.chunk_size,
+            use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
             use_lma=self.globals.use_lma,
             inplace_safe=inplace_safe,
             _mask_trans=self.config._mask_trans,
......
@@ -20,7 +20,7 @@ from typing import Optional, List
 import torch
 import torch.nn as nn

-from openfold.model.primitives import Linear, LayerNorm, Attention
+from openfold.model.primitives import LayerNorm, Attention
 from openfold.model.dropout import (
     DropoutRowwise,
     DropoutColumnwise,
@@ -46,7 +46,6 @@ from openfold.utils.feats import (
 from openfold.utils.tensor_utils import (
     add,
     permute_final_dims,
-    flatten_final_dims,
     tensor_tree_map,
 )
@@ -200,7 +199,8 @@ class TemplatePairStackBlock(nn.Module):
     def forward(self,
         z: torch.Tensor,
         mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,
@@ -226,6 +226,7 @@ class TemplatePairStackBlock(nn.Module):
                 single,
                 chunk_size=_attn_chunk_size,
                 mask=single_mask,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma,
                 inplace_safe=inplace_safe,
             )
@@ -239,6 +240,7 @@ class TemplatePairStackBlock(nn.Module):
                 single,
                 chunk_size=_attn_chunk_size,
                 mask=single_mask,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma,
                 inplace_safe=inplace_safe,
             )
@@ -355,6 +357,7 @@ class TemplatePairStack(nn.Module):
         t: torch.tensor,
         mask: torch.tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,
@@ -378,6 +381,7 @@ class TemplatePairStack(nn.Module):
             b,
             mask=mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
             _mask_trans=_mask_trans,
@@ -468,6 +472,7 @@ def embed_templates_offload(
         t.unsqueeze(templ_dim),
         pair_mask.unsqueeze(-3).to(dtype=z.dtype),
         chunk_size=model.globals.chunk_size,
+        use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
         use_lma=model.globals.use_lma,
         _mask_trans=model.config._mask_trans,
     )
@@ -585,6 +590,7 @@ def embed_templates_average(
         t,
         pair_mask.unsqueeze(-3).to(dtype=z.dtype),
         chunk_size=model.globals.chunk_size,
+        use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
         use_lma=model.globals.use_lma,
         _mask_trans=model.config._mask_trans,
     )
......
@@ -10,7 +10,6 @@ import numpy as np
 from openfold.config import model_config
 from openfold.model.model import AlphaFold
 from openfold.utils.import_weights import import_jax_weights_
-from tests.config import consts

 # Give JAX some GPU memory discipline
 # (by default it hogs 90% of GPU memory. This disables that behavior and also
@@ -19,6 +18,18 @@ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
 os.environ["JAX_PLATFORM_NAME"] = "gpu"

+def skip_unless_ds4s_installed():
+    deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
+    ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
+        "deepspeed.ops.deepspeed4science") is not None
+    return unittest.skipUnless(ds4s_is_installed, "Requires DeepSpeed with version ≥ 0.10.4")
+
+
+def skip_unless_flash_attn_installed():
+    fa_is_installed = importlib.util.find_spec("flash_attn") is not None
+    return unittest.skipUnless(fa_is_installed, "Requires Flash Attention")
+
+
 def alphafold_is_installed():
     return importlib.util.find_spec("alphafold") is not None
......
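
The two new skip helpers follow the standard pattern of probing for a module and returning a unittest decorator. A self-contained sketch of that pattern (generic module name, hypothetical test class):

    import importlib.util
    import unittest

    def skip_unless_installed(module_name):
        # find_spec returns None when the module cannot be imported, so the
        # test is skipped cleanly instead of failing with an ImportError.
        installed = importlib.util.find_spec(module_name) is not None
        return unittest.skipUnless(installed, f"Requires {module_name}")

    class ExampleTest(unittest.TestCase):
        @skip_unless_installed("deepspeed")
        def test_kernel(self):
            self.assertTrue(True)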
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import torch
 import numpy as np
 from scipy.spatial.transform import Rotation
@@ -95,3 +96,14 @@ def random_affines_4x4(dim):
     affines[:, 3, 3] = 1
     return affines.reshape(*dim, 4, 4)

+
+def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9, dtype=torch.float32):
+    q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype).cuda()
+    kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype).cuda()
+    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype).cuda()
+    biases = [inf * (mask - 1), torch.rand(batch_size, 1, no_heads, n, n)]
+    biases = [b.to(dtype=dtype).cuda() for b in biases]
+    return q, kv, mask, biases
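
Note how the helper encodes the 0/1 mask as an additive attention bias: inf * (mask - 1) maps kept positions (mask = 1) to 0 and masked positions (mask = 0) to -1e9, which vanish after the softmax. A tiny check of that arithmetic:

    import torch

    inf = 1e9
    mask = torch.tensor([[1., 0., 1.]])   # 1 = attend, 0 = masked out
    bias = inf * (mask - 1)               # -> [0., -1e9, 0.]
    probs = torch.softmax(torch.zeros(1, 3) + bias, dim=-1)
    print(probs)                          # tensor([[0.5, 0.0, 0.5]])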
@@ -22,45 +22,63 @@ import unittest
 import numpy as np
 import pickle

-from openfold.data import data_transforms
 from openfold.model.primitives import (
-    _attention,
-    _deepspeed_evo_attn
+    lecun_normal_init_,
+    Attention,
 )
-from openfold.utils.tensor_utils import tensor_tree_map
 from tests.config import consts
 import tests.compare_utils as compare_utils
+from tests.data_utils import random_template_feats, random_attention_inputs
+from openfold.data import data_transforms
+from openfold.utils.tensor_utils import tensor_tree_map


+@compare_utils.skip_unless_ds4s_installed()
 class TestDeepSpeedKernel(unittest.TestCase):
-    def test_ds_kernel_vs_attention(self):
+    def compare_attention_types(self, use_flash=False):
         """Compare attention with and without using DeepSpeed Evoformer kernel."""
         batch_size = consts.batch_size
-        c_hidden = 32
+        n_seq = consts.n_seq
         n = 2 ** 12
-        n_seq = 12
+        c_hidden = 32
         no_heads = 4
-        dtype = torch.bfloat16
+        eps = 2e-2

-        q = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
-        k = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
-        v = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
+        q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
+                                                      n_seq=n_seq,
+                                                      n=n,
+                                                      no_heads=no_heads,
+                                                      c_hidden=c_hidden)

-        bias = [torch.rand(batch_size, n_seq, 1, 1, n), torch.rand(batch_size, 1, no_heads, n, n)]
-        bias = [b.to(dtype=dtype).cuda() for b in bias]
+        a = Attention(
+            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
+        ).cuda()

         with torch.no_grad():
-            l = _deepspeed_evo_attn(q, k, v, biases=bias).cpu()
+            lecun_normal_init_(a.linear_g.weight)
+            lecun_normal_init_(a.linear_o.weight)

-            q = q.transpose(-2, -3)
-            k = k.transpose(-2, -3)
-            v = v.transpose(-2, -3)
-            real = _attention(q, k, v, biases=bias)
-            real = real.transpose(-2, -3).cpu()
+            if use_flash:
+                biases = [biases[0]]
+                flash_mask = mask.reshape(batch_size * n_seq, n)
+                real_out = a(q, kv, use_flash=True, flash_mask=flash_mask).cpu()
+            else:
+                real_out = a(q, kv, biases=biases).cpu()
+
+            ds_out = a(q, kv, biases=biases, use_deepspeed_evo_attention=True).cpu()

-        err = torch.max(torch.abs(l - real))
-        self.assertTrue(err < consts.eps, f'Error: {err}')
+        err = torch.max(torch.abs(ds_out - real_out))
+        self.assertTrue(err < eps, f'Error: {err}')
+
+    def test_ds_kernel_vs_attention(self):
+        """Compare regular attention vs. DeepSpeed Evoformer kernel."""
+        self.compare_attention_types(use_flash=False)
+
+    @compare_utils.skip_unless_flash_attn_installed()
+    def test_ds_kernel_vs_flash_attention(self):
+        """Compare Flash Attention vs. DeepSpeed Evoformer kernel."""
+        self.compare_attention_types(use_flash=True)

     def compare_evoformer(self, dtype):
         """
@@ -70,7 +88,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
         """
         n_res = 20
         n_seq = 18
-        eps = 2e-2
+        eps = 0.5

         activations = {
             "msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
@@ -111,8 +129,11 @@ class TestDeepSpeedKernel(unittest.TestCase):
         out_repro_msa_ds = out_repro_msa_ds.cpu()
         out_repro_pair_ds = out_repro_pair_ds.cpu()

-        self.assertTrue(torch.allclose(torch.abs(out_repro_msa), torch.abs(out_repro_msa_ds), atol=eps))
-        self.assertTrue(torch.allclose(torch.abs(out_repro_pair), torch.abs(out_repro_pair_ds), atol=eps))
+        err = torch.mean(torch.abs(out_repro_msa - out_repro_msa_ds))
+        self.assertTrue(err < eps, f'MSA Error: {err}')
+        err = torch.mean(torch.abs(out_repro_pair - out_repro_pair_ds))
+        self.assertTrue(err < eps, f'Pair Error {err}')

     def test_compare_evoformer_bf16(self):
         """Run evoformer comparison test with BF16 precision."""
@@ -122,12 +143,54 @@ class TestDeepSpeedKernel(unittest.TestCase):
         """Run evoformer comparison test with FP32 precision."""
         self.compare_evoformer(torch.float32)

+    def test_compare_template_stack(self):
+        """
+        Compare Template Stack output with and without using DeepSpeed Evoformer attention kernel.
+        Kernel can be used for Triangle Attention in the Template Pair Stack.
+        """
+        n_templ = consts.n_templ
+        n_res = 20
+        eps = 2e-2
+
+        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
+        batch = random_template_feats(n_templ, n_res)
+        batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
+        pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
+        inds = np.random.randint(0, 21, (n_res,))
+        batch["target_feat"] = np.eye(22)[inds]
+
+        with torch.no_grad():
+            model = compare_utils.get_global_pretrained_openfold()
+            model.globals.use_deepspeed_evo_attention = False
+            out_repro = model.embed_templates(
+                {k: torch.as_tensor(v).cuda() for k, v in batch.items()},
+                torch.as_tensor(pair_act).cuda(),
+                torch.as_tensor(pair_mask).cuda(),
+                templ_dim=0,
+                inplace_safe=False
+            )
+            out_repro = out_repro["template_pair_embedding"].cpu()
+
+            model.globals.use_deepspeed_evo_attention = True
+            out_repro_ds = model.embed_templates(
+                {k: torch.as_tensor(v).cuda() for k, v in batch.items()},
+                torch.as_tensor(pair_act).cuda(),
+                torch.as_tensor(pair_mask).cuda(),
+                templ_dim=0,
+                inplace_safe=False
+            )
+            out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
+
+            err = torch.max(torch.abs(out_repro - out_repro_ds))
+            self.assertTrue(err < eps, f'Error {err}')

     def test_compare_model(self):
         """
         Run full model with and without using DeepSpeed Evoformer attention kernel
         and compare output coordinates
         """
-        eps = 2e-2
+        eps = 0.5

         with open("tests/test_data/sample_feats.pickle", "rb") as fp:
             batch = pickle.load(fp)
......
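
For reference, the kernel these tests exercise is DeepSpeed's DS4Sci_EvoformerAttention (available from DeepSpeed 0.10.4, hence the skip guard above). A hedged sketch of a direct call, with shapes matching the test setup ([batch, n_seq, n, no_heads, c_hidden], bf16 on CUDA); the exact accepted bias shapes and signature may vary between DeepSpeed versions:

    import torch
    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

    q = torch.rand(1, 8, 256, 4, 32, dtype=torch.bfloat16, device="cuda")
    k = torch.rand_like(q)
    v = torch.rand_like(q)
    # Additive mask bias (broadcast per head) and per-head pair bias,
    # mirroring the biases list built by random_attention_inputs above.
    mask = torch.randint(0, 2, (1, 8, 1, 1, 256), device="cuda").to(torch.bfloat16)
    mask_bias = 1e9 * (mask - 1)
    pair_bias = torch.rand(1, 1, 4, 256, 256, dtype=torch.bfloat16, device="cuda")
    out = DS4Sci_EvoformerAttention(q, k, v, [mask_bias, pair_bias])  # same shape as q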
@@ -13,40 +13,37 @@
 # limitations under the License.

 import torch
-import numpy as np
 import unittest

 from openfold.model.primitives import (
-    _lma,
-    _attention,
-    DEFAULT_LMA_Q_CHUNK_SIZE,
-    DEFAULT_LMA_KV_CHUNK_SIZE
+    lecun_normal_init_,
+    Attention,
 )
 from tests.config import consts
+from tests.data_utils import random_attention_inputs


 class TestLMA(unittest.TestCase):
     def test_lma_vs_attention(self):
-        batch_size = consts.batch_size
         c_hidden = 32
-        n = 2**12
-        n_seq = 12
         no_heads = 4

-        q = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
-        k = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
-        v = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
+        q, kv, _, biases = random_attention_inputs(batch_size=consts.batch_size,
+                                                   n_seq=consts.n_seq,
+                                                   n=2**12,
+                                                   no_heads=no_heads,
+                                                   c_hidden=c_hidden)

-        bias = [torch.rand(batch_size, n_seq, 1, 1, n), torch.rand(batch_size, 1, no_heads, n, n)]
-        biases = [b.cuda() for b in bias]
+        a = Attention(
+            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
+        ).cuda()

         with torch.no_grad():
-            lma_biases = [
-                b.expand(b.shape[:-2] + (q.shape[-2],) + (k.shape[-2],))
-                for b in biases
-            ]
-            l = _lma(q, k, v, lma_biases, DEFAULT_LMA_Q_CHUNK_SIZE, DEFAULT_LMA_KV_CHUNK_SIZE).cpu()
-            real = _attention(q, k, v, biases).cpu()
+            lecun_normal_init_(a.linear_g.weight)
+            lecun_normal_init_(a.linear_o.weight)
+
+            l = a(q, kv, biases=biases, use_lma=True).cpu()
+            real = a(q, kv, biases=biases).cpu()

         err = torch.max(torch.abs(l - real))
         self.assertTrue(err < consts.eps, f'Error: {err}')
......
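
Background on the LMA path being tested: low-memory attention trades speed for memory by computing attention over chunks so the full n x n score matrix is never materialized at once. A self-contained sketch of the chunking idea (not OpenFold's internal _lma, which also chunks keys/values with an online softmax):

    import torch

    def q_chunked_attention(q, k, v, q_chunk_size=1024):
        # q, k, v: [..., n, d]; processes queries q_chunk_size rows at a time,
        # so peak memory scales with q_chunk_size * n instead of n * n.
        out = []
        scale = q.shape[-1] ** -0.5
        for i in range(0, q.shape[-2], q_chunk_size):
            qc = q[..., i:i + q_chunk_size, :]
            scores = torch.einsum("...qd,...kd->...qk", qc, k) * scale
            out.append(torch.softmax(scores, dim=-1) @ v)
        return torch.cat(out, dim=-2)

    # Matches full attention up to float tolerance:
    q, k, v = (torch.rand(2, 4096, 32) for _ in range(3))
    full = torch.softmax((q @ k.transpose(-1, -2)) * 32 ** -0.5, dim=-1) @ v
    assert torch.allclose(q_chunked_attention(q, k, v), full, atol=1e-5)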