"icp/tests/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "63d178f1ed8d1a49d4928b6012987af20288fd84"
Commit 6ebcd8b4 authored by Christina Floristean
Browse files

Correct unit tests to run attention functions instead of Attention module in...

Correct unit tests to run attention functions instead of Attention module in order to avoid 'final' init on outputs
parent 710088d9
......@@ -26,7 +26,8 @@ import numpy as np
import pickle
from openfold.model.primitives import (
Attention,
_attention,
_deepspeed_evo_attn
)
from tests.config import consts
import tests.compare_utils as compare_utils
......@@ -43,22 +44,26 @@ class TestDeepSpeedKernel(unittest.TestCase):
n = 2 ** 12
n_seq = 12
no_heads = 4
dtype = torch.bfloat16
q = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
kv = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
q = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
k = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
v = torch.rand(batch_size, n_seq, n, no_heads, c_hidden, dtype=dtype).cuda()
bias = [torch.rand(batch_size, n_seq, 1, 1, n), torch.rand(batch_size, 1, no_heads, n, n)]
bias = [b.cuda() for b in bias]
a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
bias = [b.to(dtype=dtype).cuda() for b in bias]
with torch.no_grad():
l = a(q, kv, biases=bias, use_deepspeed_evo_attention=True)
real = a(q, kv, biases=bias)
l = _deepspeed_evo_attn(q, k, v, biases=bias).cpu()
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
real = _attention(q, k, v, biases=bias)
real = real.transpose(-2, -3).cpu()
self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
err = torch.max(torch.abs(l - real))
self.assertTrue(err < consts.eps, f'Error: {err}')
def compare_evoformer(self, dtype):
"""
......@@ -112,17 +117,14 @@ class TestDeepSpeedKernel(unittest.TestCase):
self.assertTrue(torch.allclose(torch.abs(out_repro_msa), torch.abs(out_repro_msa_ds), atol=eps))
self.assertTrue(torch.allclose(torch.abs(out_repro_pair), torch.abs(out_repro_pair_ds), atol=eps))
@unittest.skip('Temporarily disabled')
def test_compare_evoformer_bf16(self):
"""Run evoformer comparison test with BF16 precision."""
self.compare_evoformer(torch.bfloat16)
@unittest.skip('Temporarily disabled')
def test_compare_evoformer_fp32(self):
"""Run evoformer comparison test with FP32 precision."""
self.compare_evoformer(torch.float32)
@unittest.skip('Temporarily disabled')
def test_compare_model(self):
"""
Run full model with and without using DeepSpeed Evoformer attention kernel
......
......@@ -17,7 +17,10 @@ import numpy as np
import unittest
from openfold.model.primitives import (
Attention,
_lma,
_attention,
DEFAULT_LMA_Q_CHUNK_SIZE,
DEFAULT_LMA_KV_CHUNK_SIZE
)
from tests.config import consts
......@@ -27,26 +30,26 @@ class TestLMA(unittest.TestCase):
batch_size = consts.batch_size
c_hidden = 32
n = 2**12
n_seq = 12
no_heads = 4
q = torch.rand(batch_size, n, c_hidden).cuda()
kv = torch.rand(batch_size, n, c_hidden).cuda()
q = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
k = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
v = torch.rand(batch_size, n_seq, no_heads, n, c_hidden).cuda()
bias = [torch.rand(no_heads, 1, n)]
bias = [b.cuda() for b in bias]
gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
o_fill = torch.rand(c_hidden, c_hidden * no_heads)
a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
bias = [torch.rand(batch_size, n_seq, 1, 1, n), torch.rand(batch_size, 1, no_heads, n, n)]
biases = [b.cuda() for b in bias]
with torch.no_grad():
l = a(q, kv, biases=bias, use_lma=True)
real = a(q, kv, biases=bias)
self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
lma_biases = [
b.expand(b.shape[:-2] + (q.shape[-2],) + (k.shape[-2],))
for b in biases
]
l = _lma(q, k, v, lma_biases, DEFAULT_LMA_Q_CHUNK_SIZE, DEFAULT_LMA_KV_CHUNK_SIZE).cpu()
real = _attention(q, k, v, biases).cpu()
err = torch.max(torch.abs(l - real))
self.assertTrue(err < consts.eps, f'Error: {err}')
if __name__ == "__main__":
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment