"deploy/vscode:/vscode.git/clone" did not exist on "86bc5442b4171d9a7c3de4b854dd07ca1b7a4f65"
Unverified Commit 60d0b15a authored by jnwei's avatar jnwei Committed by GitHub
Browse files

Merge pull request #350 from aqlaboratory/fix-msastack-test-error

Fixes cuda/float wrapper error in unit tests
parents 2134cc09 73ff40b6
......@@ -46,4 +46,4 @@ echo "Downloading AlphaFold parameters..."
bash scripts/download_alphafold_params.sh openfold/resources
# Decompress test data
gunzip tests/test_data/sample_feats.pickle.gz
gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.pickle
......@@ -206,7 +206,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res,
),
device="cuda",
)
).float()
pair_mask = torch.randint(
0,
2,
......@@ -216,7 +216,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res,
),
device="cuda",
)
).float()
shape_z_before = z.shape
......
......@@ -47,27 +47,27 @@ class TestModel(unittest.TestCase):
c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test
model = AlphaFold(c)
model = AlphaFold(c).cuda()
model.eval()
batch = {}
tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,)).cuda()
batch["target_feat"] = nn.functional.one_hot(
tf, c.model.input_embedder.tf_dim
).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
).float().cuda()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1).cuda()
batch["residue_index"] = torch.arange(n_res).cuda()
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)).cuda()
t_feats = random_template_feats(n_templ, n_res)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
batch.update({k: torch.tensor(v).cuda() for k, v in t_feats.items()})
extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
batch.update({k: torch.tensor(v).cuda() for k, v in extra_feats.items()})
batch["msa_mask"] = torch.randint(
low=0, high=2, size=(n_seq, n_res)
).float()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
).float().cuda()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float().cuda()
batch.update(data_transforms.make_atom14_masks(batch))
batch["no_recycling_iters"] = torch.tensor(2.)
batch["no_recycling_iters"] = torch.tensor(2.).cuda()
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment