Commit 777d738a authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for PreembeddingEmbedder

parent 6012b9e1
......@@ -17,6 +17,7 @@ import numpy as np
import unittest
from openfold.model.embedders import (
InputEmbedder,
PreembeddingEmbedder,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
......@@ -46,6 +47,28 @@ class TestInputEmbedder(unittest.TestCase):
self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))
class TestPreembeddingEmbedder(unittest.TestCase):
def test_shape(self):
tf_dim = 22
preembedding_dim = 1280
c_z = 4
c_m = 6
relpos_k = 10
batch_size = 4
num_res = 20
tf = torch.rand((batch_size, num_res, tf_dim))
ri = torch.rand((batch_size, num_res))
preemb = torch.rand((batch_size, num_res, preembedding_dim))
pe = PreembeddingEmbedder(tf_dim, preembedding_dim, c_z, c_m, relpos_k)
seq_emb, pair_emb = pe(tf, ri, preemb)
self.assertTrue(seq_emb.shape == (batch_size, 1, num_res, c_m))
self.assertTrue(pair_emb.shape == (batch_size, num_res, num_res, c_z))
class TestRecyclingEmbedder(unittest.TestCase):
def test_shape(self):
batch_size = 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment