Commit 1606ac08 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merge branch 'main' into multimer

parents 58d65692 67a00a6c
......@@ -15,33 +15,39 @@
import torch
import unittest
from openfold.model.primitives import Attention
from openfold.model.primitives import (
lecun_normal_init_,
Attention,
)
from tests.config import consts
from tests.data_utils import random_attention_inputs
class TestLMA(unittest.TestCase):
    """Checks that low-memory attention (LMA) matches standard attention.

    Requires a CUDA device: inputs and the Attention module are moved to
    the GPU before the forward passes.
    """

    def test_lma_vs_attention(self):
        """Run the same Attention module with and without use_lma and
        assert the outputs agree to within consts.eps."""
        c_hidden = 32
        no_heads = 4

        # Shared helper builds q/kv tensors and the bias list so all
        # attention tests construct their inputs consistently.
        q, kv, _, biases = random_attention_inputs(
            batch_size=consts.batch_size,
            n_seq=consts.n_seq,
            n=2 ** 12,
            no_heads=no_heads,
            c_hidden=c_hidden,
        )

        a = Attention(
            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
        ).cuda()

        with torch.no_grad():
            # Re-initialize the gating/output projections so the
            # comparison is not run against all-default weights.
            lecun_normal_init_(a.linear_g.weight)
            lecun_normal_init_(a.linear_o.weight)

            l = a(q, kv, biases=biases, use_lma=True).cpu()
            real = a(q, kv, biases=biases).cpu()

        # Report the max absolute deviation on failure to aid debugging.
        err = torch.max(torch.abs(l - real))
        self.assertTrue(err < consts.eps, f'Error: {err}')
# Allow the test module to be run directly, e.g. `python test_lma.py`.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment