Commit 8baae516 authored by Jennifer Wei

Adds .cuda() wrappers to the PyTorch tensors used in TestModel.test_dry_run so the dry-run test runs on GPU.

parent 48668ca3
@@ -47,27 +47,27 @@ class TestModel(unittest.TestCase):
         c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
                                                        # deepspeed for this test
-        model = AlphaFold(c)
+        model = AlphaFold(c).cuda()
         model.eval()
         batch = {}
-        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
+        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,)).cuda()
         batch["target_feat"] = nn.functional.one_hot(
             tf, c.model.input_embedder.tf_dim
-        ).float()
-        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
-        batch["residue_index"] = torch.arange(n_res)
-        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
+        ).float().cuda()
+        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1).cuda()
+        batch["residue_index"] = torch.arange(n_res).cuda()
+        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)).cuda()
         t_feats = random_template_feats(n_templ, n_res)
-        batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
+        batch.update({k: torch.tensor(v).cuda() for k, v in t_feats.items()})
         extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
-        batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
+        batch.update({k: torch.tensor(v).cuda() for k, v in extra_feats.items()})
         batch["msa_mask"] = torch.randint(
             low=0, high=2, size=(n_seq, n_res)
-        ).float()
-        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
+        ).float().cuda()
+        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float().cuda()
         batch.update(data_transforms.make_atom14_masks(batch))
-        batch["no_recycling_iters"] = torch.tensor(2.)
+        batch["no_recycling_iters"] = torch.tensor(2.).cuda()
         add_recycling_dims = lambda t: (
             t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
...
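
The diff above appends .cuda() to each tensor individually. A minimal sketch of an alternative, assuming the batch is a flat dict of tensors: build the features on CPU and move them to a device in one pass. The batch_to_device helper and the device guard below are illustrative only, not part of the commit or of OpenFold.

    import torch

    # Prefer CUDA when available so the same setup also runs on CPU-only machines
    # (hypothetical guard, not in the original test).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def batch_to_device(batch, device):
        """Move every tensor value in a flat feature dict to the given device."""
        return {
            k: v.to(device) if torch.is_tensor(v) else v
            for k, v in batch.items()
        }

With a helper like this, the per-tensor edits would reduce to model = AlphaFold(c).to(device) and batch = batch_to_device(batch, device), and the test would degrade gracefully when no GPU is present.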