Commit f545323c authored by Christina Floristean

Added test for backward pass

parent a3de9cb9
@@ -98,12 +98,17 @@ def random_affines_4x4(dim):
     return affines.reshape(*dim, 4, 4)
 
 
-def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9, dtype=torch.float32):
-    q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype).cuda()
-    kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype).cuda()
-    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype).cuda()
-    biases = [inf * (mask - 1), torch.rand(batch_size, 1, no_heads, n, n)]
-    biases = [b.to(dtype=dtype).cuda() for b in biases]
+def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9,
+                            dtype=torch.float32, requires_grad=False):
+    q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype, requires_grad=requires_grad).cuda()
+    kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype, requires_grad=requires_grad).cuda()
+
+    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype, requires_grad=requires_grad).cuda()
+    z_bias = torch.rand(batch_size, 1, no_heads, n, n, dtype=dtype, requires_grad=requires_grad).cuda()
+
+    mask_bias = inf * (mask - 1)
+    if requires_grad:
+        mask_bias = mask_bias.detach().clone().requires_grad_()
+
+    biases = [mask_bias, z_bias]
+
     return q, kv, mask, biases
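
Note on the requires_grad path above: mask_bias is computed from mask, so it is a non-leaf tensor and autograd would not populate its .grad; detaching and cloning makes it a leaf that the backward test can inspect. A minimal standalone sketch of that distinction (illustrative only, not part of the commit; runs on CPU):

import torch

mask = torch.randint(0, 2, (4,), dtype=torch.float32, requires_grad=True)
mask_bias = 1e9 * (mask - 1)  # non-leaf: derived from mask, .grad stays None
leaf_bias = mask_bias.detach().clone().requires_grad_()  # leaf: .grad is populated

leaf_bias.sum().backward()
print(leaf_bias.grad)  # tensor([1., 1., 1., 1.])
print(mask_bias.grad)  # None (PyTorch warns when reading .grad of a non-leaf)
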
@@ -17,15 +17,16 @@ Unit tests to compare components of OpenFold run with the DeepSpeed memory-efficient
 attention kernel, DS4Sci_EvoformerAttention vs. a stock PyTorch attention implementation.
 """
 
-import torch
 import unittest
 import numpy as np
 import pickle
+import torch
+from torch.nn import functional as F
 
 from openfold.data import data_transforms
 from openfold.model.primitives import (
     lecun_normal_init_,
-    Attention,
+    Attention
 )
 from openfold.utils.tensor_utils import tensor_tree_map
@@ -39,15 +40,15 @@ class TestDeepSpeedKernel(unittest.TestCase):
     def compare_attention_types(self, use_flash=False):
         """Compare attention with and without using DeepSpeed Evoformer kernel."""
         batch_size = consts.batch_size
-        n_seq = consts.n_seq
-        n = 2 ** 12
+        n_seq = 18
+        n_res = 20
         c_hidden = 32
         no_heads = 4
         eps = 2e-2
 
         q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
                                                       n_seq=n_seq,
-                                                      n=n,
+                                                      n=n_res,
                                                       no_heads=no_heads,
                                                       c_hidden=c_hidden)
@@ -61,7 +62,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
         if use_flash:
             biases = [biases[0]]
-            flash_mask = mask.reshape(batch_size * n_seq, n)
+            flash_mask = mask.reshape(batch_size * n_seq, n_res)
             real_out = a(q, kv, use_flash=True, flash_mask=flash_mask).cpu()
         else:
             real_out = a(q, kv, biases=biases).cpu()
@@ -71,15 +72,79 @@ class TestDeepSpeedKernel(unittest.TestCase):
         err = torch.max(torch.abs(ds_out - real_out))
         self.assertTrue(err < eps, f'Error: {err}')
 
-    def test_ds_kernel_vs_attention(self):
+    def test_ds_kernel_vs_attention_forward(self):
         """Compare regular attention vs. DeepSpeed Evoformer kernel."""
         self.compare_attention_types(use_flash=False)
 
     @compare_utils.skip_unless_flash_attn_installed()
-    def test_ds_kernel_vs_flash_attention(self):
+    def test_ds_kernel_vs_flash_attn_forward(self):
         """Compare Flash Attention vs. DeepSpeed Evoformer kernel."""
         self.compare_attention_types(use_flash=True)
 
+    def test_ds_kernel_vs_attention_backward(self):
+        """Compare backward pass for regular attention vs. DeepSpeed Evoformer kernel."""
+        batch_size = consts.batch_size
+        n_seq = 18
+        n_res = 20
+        c_hidden = 32
+        no_heads = 4
+        eps = consts.eps
+
+        q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
+                                                      n_seq=n_seq,
+                                                      n=n_res,
+                                                      no_heads=no_heads,
+                                                      c_hidden=c_hidden,
+                                                      requires_grad=True)
+
+        attn = Attention(
+            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
+        ).cuda()
+
+        with torch.no_grad():
+            lecun_normal_init_(attn.linear_g.weight)
+            lecun_normal_init_(attn.linear_o.weight)
+
+        def clone(t):
+            t = t.clone()
+            if t.requires_grad:
+                t.retain_grad()
+            return t
+
+        def init_attn():
+            a_clone = Attention(
+                c_hidden, c_hidden, c_hidden, c_hidden, no_heads
+            ).cuda()
+            a_clone.load_state_dict(attn.state_dict())
+            return a_clone
+
+        q_repro = clone(q)
+        kv_repro = clone(kv)
+        biases_repro = [clone(b) for b in biases]
+        a = init_attn()
+        out_repro = a(q_repro, kv_repro, biases=biases_repro, use_deepspeed_evo_attention=True)
+
+        loss_repro = torch.mean(out_repro)
+        loss_repro.backward()
+
+        q_gt = clone(q)
+        kv_gt = clone(kv)
+        biases_gt = [clone(b) for b in biases]
+        a = init_attn()
+        out_gt = a(q_gt, kv_gt, biases=biases_gt)
+
+        loss_gt = torch.mean(out_gt)
+        loss_gt.backward()
+
+        pairs = zip([q_repro, kv_repro, biases_repro[0], biases_repro[1]],
+                    [q_gt, kv_gt, biases_gt[0], biases_gt[1]])
+
+        for i, (t_repro, t_gt) in enumerate(pairs):
+            err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
+            self.assertTrue(err < eps, f'Error item #{i}: {err}')
+
     def compare_evoformer(self, dtype):
         """
         Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
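
The backward test above checks the max absolute gradient error per input tensor. An equivalent, slightly more informative pattern (a hypothetical alternative, not the commit's code) uses torch.testing.assert_close, which reports the offending elements and tolerances on failure:

import torch

def assert_grads_close(tensors_repro, tensors_gt, atol=2e-2):
    # Assumes backward() has already run on both graphs, so .grad is set.
    for i, (t_repro, t_gt) in enumerate(zip(tensors_repro, tensors_gt)):
        torch.testing.assert_close(
            t_repro.grad, t_gt.grad, atol=atol, rtol=0.0,
            msg=lambda m, i=i: f"gradient mismatch in item #{i}: {m}",
        )
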
@@ -88,7 +153,9 @@ class TestDeepSpeedKernel(unittest.TestCase):
         """
         n_res = 20
         n_seq = 18
-        eps = 0.5
+        c_m_shape = (consts.c_m,)
+        c_z_shape = (consts.c_z,)
+        eps = 2e-2
 
         activations = {
             "msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
@@ -113,8 +180,10 @@ class TestDeepSpeedKernel(unittest.TestCase):
             inplace_safe=False,
         )
-        out_repro_msa = out_repro_msa.cpu()
-        out_repro_pair = out_repro_pair.cpu()
+        # In practice, layer norms applied later in the network make any
+        # kernel rounding errors negligible
+        out_repro_msa = F.layer_norm(out_repro_msa, c_m_shape).cpu()
+        out_repro_pair = F.layer_norm(out_repro_pair, c_z_shape).cpu()
 
         out_repro_msa_ds, out_repro_pair_ds = model.evoformer.blocks[0](
             activations["msa"],
@@ -126,8 +195,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
             _mask_trans=False,
             inplace_safe=False,
         )
-        out_repro_msa_ds = out_repro_msa_ds.cpu()
-        out_repro_pair_ds = out_repro_pair_ds.cpu()
+        out_repro_msa_ds = F.layer_norm(out_repro_msa_ds, c_m_shape).cpu()
+        out_repro_pair_ds = F.layer_norm(out_repro_pair_ds, c_z_shape).cpu()
 
         err = torch.mean(torch.abs(out_repro_msa - out_repro_msa_ds))
         self.assertTrue(err < eps, f'MSA Error: {err}')
@@ -188,7 +257,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
     def test_compare_model(self):
         """
         Run full model with and without using DeepSpeed Evoformer attention kernel
-        and compare output coordinates
+        and compare output coordinates.
         """
         eps = 0.5
         with open("tests/test_data/sample_feats.pickle", "rb") as fp:
...
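
The final hunk is truncated above; only the docstring fix is visible. For orientation, the comparison such a full-model test performs looks roughly like the following (a hedged sketch, not the commit's code; it assumes OpenFold's globals.use_deepspeed_evo_attention config flag and the final_atom_positions output key):

import torch

def compare_model_outputs(model, batch, eps=0.5):
    # Run the same features through the model twice, toggling the kernel,
    # then compare the predicted atom coordinates.
    with torch.no_grad():
        model.globals.use_deepspeed_evo_attention = False
        out_gt = model(batch)["final_atom_positions"].cpu()

        model.globals.use_deepspeed_evo_attention = True
        out_ds = model(batch)["final_atom_positions"].cpu()

    err = torch.max(torch.abs(out_gt - out_ds))
    assert err < eps, f"Coordinate error: {err}"
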