Commit 1606ac08 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merge branch 'main' into multimer

parents 58d65692 67a00a6c
...@@ -15,33 +15,39 @@ ...@@ -15,33 +15,39 @@
import torch import torch
import unittest import unittest
from openfold.model.primitives import Attention from openfold.model.primitives import (
lecun_normal_init_,
Attention,
)
from tests.config import consts from tests.config import consts
from tests.data_utils import random_attention_inputs
class TestLMA(unittest.TestCase):
    def test_lma_vs_attention(self):
        """Check that low-memory attention (LMA) matches standard attention.

        Runs the same Attention module twice on identical random inputs —
        once with use_lma=True, once without — and asserts the outputs agree
        to within consts.eps. Requires a CUDA device.
        """
        c_hidden = 32
        no_heads = 4

        # Shared random q/kv tensors and attention biases so both code paths
        # see identical inputs. (Helper lives in tests.data_utils.)
        q, kv, _, biases = random_attention_inputs(batch_size=consts.batch_size,
                                                   n_seq=consts.n_seq,
                                                   n=2 ** 12,
                                                   no_heads=no_heads,
                                                   c_hidden=c_hidden)

        a = Attention(
            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
        ).cuda()

        with torch.no_grad():
            # The gating/output projections are zero-initialized by default,
            # which would make both outputs trivially zero — re-initialize
            # them so the comparison is meaningful.
            lecun_normal_init_(a.linear_g.weight)
            lecun_normal_init_(a.linear_o.weight)

            lma_out = a(q, kv, biases=biases, use_lma=True).cpu()
            real = a(q, kv, biases=biases).cpu()

        err = torch.max(torch.abs(lma_out - real))
        # assertLess reports both operands on failure, unlike assertTrue.
        self.assertLess(err, consts.eps, f'Error: {err}')
# Allow running this test module directly (e.g. `python test_lma.py`).
if __name__ == "__main__":
    unittest.main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment