Unverified commit 60d0b15a, authored by jnwei, committed by GitHub

Merge pull request #350 from aqlaboratory/fix-msastack-test-error

Fixes cuda/float wrapper error in unit tests
Parents: 2134cc09, 73ff40b6
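For context, the one-line summary covers two related problems in the test setup: masks built with `torch.randint` are `int64` tensors that were being handed to float arithmetic, and several inputs were left on the CPU while the model under test runs on a CUDA device. The hunks below add the missing `.float()` casts and `.cuda()` moves, and also make the test-data decompression step repeatable.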
@@ -46,4 +46,4 @@ echo "Downloading AlphaFold parameters..."
 bash scripts/download_alphafold_params.sh openfold/resources

 # Decompress test data
-gunzip tests/test_data/sample_feats.pickle.gz
+gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.pickle
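A note on the `gunzip` change: by default, `gunzip` deletes `sample_feats.pickle.gz` after decompressing it, so a second run of the setup script fails once the archive is gone. `gunzip -c` writes to stdout and leaves the archive in place, which is presumably why the script now redirects the output to the `.pickle` path instead; the step becomes idempotent.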
@@ -206,7 +206,7 @@ class TestExtraMSAStack(unittest.TestCase):
                 n_res,
             ),
             device="cuda",
-        )
+        ).float()
         pair_mask = torch.randint(
             0,
             2,
@@ -216,7 +216,7 @@ class TestExtraMSAStack(unittest.TestCase):
                 n_res,
             ),
             device="cuda",
-        )
+        ).float()
         shape_z_before = z.shape
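For context on the two `.float()` casts: `torch.randint` returns `int64` tensors, and the masks built here are later combined with `float32` activations inside the ExtraMSAStack, where at least some kernels require matching dtypes. A minimal sketch of the cast pattern, with illustrative names and shapes rather than the test's own:

```python
import torch

n_seq, n_res = 8, 16  # illustrative sizes, not the test's

msa_act = torch.rand(n_seq, n_res, 32)          # float32 activations
msa_mask = torch.randint(0, 2, (n_seq, n_res))  # int64 by default
assert msa_mask.dtype == torch.int64

# The fix the diff applies: cast the mask before it meets float math.
msa_mask = msa_mask.float()
masked_act = msa_act * msa_mask.unsqueeze(-1)   # float32 throughout
```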
@@ -47,27 +47,27 @@ class TestModel(unittest.TestCase):
         c.model.evoformer_stack.blocks_per_ckpt = None  # don't want to set up
                                                         # deepspeed for this test
-        model = AlphaFold(c)
+        model = AlphaFold(c).cuda()
         model.eval()

         batch = {}
-        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
+        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,)).cuda()
         batch["target_feat"] = nn.functional.one_hot(
             tf, c.model.input_embedder.tf_dim
-        ).float()
-        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
-        batch["residue_index"] = torch.arange(n_res)
-        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
+        ).float().cuda()
+        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1).cuda()
+        batch["residue_index"] = torch.arange(n_res).cuda()
+        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)).cuda()
         t_feats = random_template_feats(n_templ, n_res)
-        batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
+        batch.update({k: torch.tensor(v).cuda() for k, v in t_feats.items()})
         extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
-        batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
+        batch.update({k: torch.tensor(v).cuda() for k, v in extra_feats.items()})
         batch["msa_mask"] = torch.randint(
             low=0, high=2, size=(n_seq, n_res)
-        ).float()
-        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
+        ).float().cuda()
+        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float().cuda()
         batch.update(data_transforms.make_atom14_masks(batch))
-        batch["no_recycling_iters"] = torch.tensor(2.)
+        batch["no_recycling_iters"] = torch.tensor(2.).cuda()

         add_recycling_dims = lambda t: (
             t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
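The second half of the fix is device placement: once the model is constructed with `.cuda()`, every tensor in the batch must live on the same device, or the first GPU op raises a device-mismatch error. The diff appends `.cuda()` tensor by tensor; an equivalent, more compact pattern is to move the whole feature dict at once. A minimal sketch, assuming a flat dict of tensors; `batch_to_device` is an illustrative helper, not a function from the repository:

```python
import torch

def batch_to_device(batch: dict, device: torch.device) -> dict:
    """Move every tensor in a flat feature dict onto one device."""
    return {
        k: v.to(device) if torch.is_tensor(v) else v
        for k, v in batch.items()
    }

# Usage, mirroring the test setup above:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = batch_to_device({"seq_mask": torch.ones(16)}, device)
```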