Commit 6ebcd8b4 authored by Christina Floristean

Correct unit tests to run attention functions instead of Attention module in order to avoid 'final' init on outputs
parent 710088d9
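Background on the rationale: the 'final' initialization referenced in the commit message zero-initializes the Attention module's output projection, so an untrained module maps every input to (near-)zero output and a comparison between two attention code paths routed through it can pass trivially. Calling the underlying attention functions directly avoids that. A minimal, hypothetical sketch of the effect (plain torch, not OpenFold code):

    import torch

    # Stand-in for an output projection created with a 'final'-style init,
    # which (per the commit message's rationale) zero-initializes the layer.
    out_proj = torch.nn.Linear(32, 32)
    torch.nn.init.zeros_(out_proj.weight)
    torch.nn.init.zeros_(out_proj.bias)

    # Two different "attention" results collapse to identical zero outputs,
    # so a max-abs-difference check between them can never fail.
    kernel_out = out_proj(torch.rand(4, 32))
    reference_out = out_proj(torch.rand(4, 32))
    print(torch.max(torch.abs(kernel_out - reference_out)))  # tensor(0.)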
@@ -26,7 +26,8 @@ import numpy as np
 import pickle
 from openfold.model.primitives import (
-    Attention,
+    _attention,
+    _deepspeed_evo_attn
 )
 from tests.config import consts
 import tests.compare_utils as compare_utils
@@ -43,22 +44,26 @@ class TestDeepSpeedKernel(unittest.TestCase):
         n = 2 ** 12
         n_seq = 12
         no_heads = 4
+        dtype = torch.bfloat16
-        q = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
+        q = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
-        kv = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
+        k = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
+        v = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
         bias = [torch.rand(batch_size, n_seq, 1, 1, n), torch.rand(batch_size, 1, no_heads, n, n)]
-        bias = [b.cuda() for b in bias]
+        bias = [b.to(dtype=dtype).cuda() for b in bias]
-        a = Attention(
-            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
-        ).cuda()
         with torch.no_grad():
-            l = a(q, kv, biases=bias, use_deepspeed_evo_attention=True)
-            real = a(q, kv, biases=bias)
+            l = _deepspeed_evo_attn(q, k, v, biases=bias).cpu()
+            q = q.transpose(-2, -3)
+            k = k.transpose(-2, -3)
+            v = v.transpose(-2, -3)
+            real = _attention(q, k, v, biases=bias)
+            real = real.transpose(-2, -3).cpu()
-        self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
+        err = torch.max(torch.abs(l - real))
+        self.assertTrue(err < consts.eps, f'Error: {err}')

     def compare_evoformer(self, dtype):
         """
@@ -112,17 +117,14 @@ class TestDeepSpeedKernel(unittest.TestCase):
         self.assertTrue(torch.allclose(torch.abs(out_repro_msa), torch.abs(out_repro_msa_ds), atol=eps))
         self.assertTrue(torch.allclose(torch.abs(out_repro_pair), torch.abs(out_repro_pair_ds), atol=eps))

-    @unittest.skip('Temporarily disabled')
     def test_compare_evoformer_bf16(self):
         """Run evoformer comparison test with BF16 precision."""
         self.compare_evoformer(torch.bfloat16)

-    @unittest.skip('Temporarily disabled')
     def test_compare_evoformer_fp32(self):
         """Run evoformer comparison test with FP32 precision."""
         self.compare_evoformer(torch.float32)

-    @unittest.skip('Temporarily disabled')
     def test_compare_model(self):
         """
         Run full model with and without using DeepSpeed Evoformer attention kernel
...
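Note on the transposes in the DeepSpeed test above: the kernel path consumes tensors laid out as [*, n_seq, n_res, no_heads, c_hidden], while the reference _attention call expects the heads axis ahead of the residue axis, so the test transposes the two axes before the reference call and transposes the result back before comparing. A small CPU-only sketch of that round trip (sizes here are arbitrary, not the test's values):

    import torch

    batch, n_seq, n_res, no_heads, c_hidden = 1, 2, 8, 4, 16

    # Layout fed to the DeepSpeed kernel path: [*, n_seq, n_res, no_heads, c_hidden]
    q = torch.rand(batch, n_seq, n_res, no_heads, c_hidden)

    # Layout used for the reference _attention call: heads before residues
    q_ref = q.transpose(-2, -3)
    assert q_ref.shape == (batch, n_seq, no_heads, n_res, c_hidden)

    # The reference output is transposed back into the kernel layout before comparison
    out_ref = q_ref  # placeholder for the reference attention output
    assert torch.equal(out_ref.transpose(-2, -3), q)  # round trip restores the layout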
@@ -17,7 +17,10 @@ import numpy as np
 import unittest
 from openfold.model.primitives import (
-    Attention,
+    _lma,
+    _attention,
+    DEFAULT_LMA_Q_CHUNK_SIZE,
+    DEFAULT_LMA_KV_CHUNK_SIZE
 )
 from tests.config import consts
@@ -27,26 +30,26 @@ class TestLMA(unittest.TestCase):
         batch_size = consts.batch_size
         c_hidden = 32
         n = 2**12
+        n_seq = 12
         no_heads = 4
-        q = torch.rand(batch_size, n, c_hidden).cuda()
+        q = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
-        kv = torch.rand(batch_size, n, c_hidden).cuda()
+        k = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
+        v = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
-        bias = [torch.rand(no_heads, 1, n)]
+        bias = [torch.rand(batch_size, n_seq, 1, 1, n), torch.rand(batch_size, 1, no_heads, n, n)]
-        bias = [b.cuda() for b in bias]
+        biases = [b.cuda() for b in bias]
-        gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
-        o_fill = torch.rand(c_hidden, c_hidden * no_heads)
-        a = Attention(
-            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
-        ).cuda()
         with torch.no_grad():
-            l = a(q, kv, biases=bias, use_lma=True)
-            real = a(q, kv, biases=bias)
+            lma_biases = [
+                b.expand(b.shape[:-2] + (q.shape[-2],) + (k.shape[-2],))
+                for b in biases
+            ]
+            l = _lma(q, k, v, lma_biases, DEFAULT_LMA_Q_CHUNK_SIZE, DEFAULT_LMA_KV_CHUNK_SIZE).cpu()
+            real = _attention(q, k, v, biases).cpu()
-        self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
+        err = torch.max(torch.abs(l - real))
+        self.assertTrue(err < consts.eps, f'Error: {err}')

 if __name__ == "__main__":
...
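Note on the bias expansion in the LMA test above: the test's biases carry broadcastable singleton dimensions (e.g. [batch, n_seq, 1, 1, n]), but the low-memory path processes queries and keys in chunks, so the bias is first expanded over the full (n_queries, n_keys) grid so that per-chunk slices have real extent. A small sketch of that expansion (arbitrary sizes, plain torch rather than OpenFold code):

    import torch

    batch, n_seq, no_heads, n_res, c_hidden = 1, 2, 4, 8, 16
    q = torch.rand(batch, n_seq, no_heads, n_res, c_hidden)
    k = torch.rand(batch, n_seq, no_heads, n_res, c_hidden)

    # Broadcastable bias with singleton head and query dimensions, as in the test
    bias = torch.rand(batch, n_seq, 1, 1, n_res)

    # Expand the last two axes to the full (n_queries, n_keys) grid, mirroring the test
    lma_bias = bias.expand(bias.shape[:-2] + (q.shape[-2],) + (k.shape[-2],))
    assert lma_bias.shape == (batch, n_seq, 1, n_res, n_res)

    # Chunk-wise slices along the query axis now have real extent instead of size 1
    assert lma_bias[..., 0:4, :].shape[-2] == 4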