Commit 0aa69474 authored by Jennifer

fix deepspeed_evo_attention to work in both monomer and multimer settings.

parent 204ed191
@@ -244,9 +244,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
         pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
         pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
-        inds = np.random.randint(0, 21, (n_res,))
-        batch["target_feat"] = np.eye(22)[inds]
         batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
         template_feats = {
             k: v for k, v in batch.items() if k.startswith("template_")
@@ -309,7 +306,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
         batch["residx_atom37_to_atom14"] = batch[
             "residx_atom37_to_atom14"
         ].long()
-        batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], 21).to(torch.float32)
+        # print(batch["target_feat"].shape)
+        batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
         batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
         batch.update(
             data_transforms.atom37_to_torsion_angles("template_")(batch)
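
For context, here is a minimal sketch of the idea behind the new target_feat line: deriving the one-hot depth from the test constants instead of hard-coding 21 lets the same code produce the right target_feat width for both monomer and multimer presets. The concrete numbers below (msa_logits of 23 for monomer, 22 for multimer) are assumptions for illustration, not values taken from this diff.

import torch

# Hypothetical values standing in for the test's consts.msa_logits; the point is
# that the one-hot depth is derived from the config rather than hard-coded.
msa_logits = 23                  # assumed monomer-style value; a multimer preset would use 22
tf_dim = msa_logits - 1          # 22 for monomer, 21 for multimer

aatype = torch.randint(0, 20, (16,))  # toy 16-residue sequence of amino-acid indices
target_feat = torch.nn.functional.one_hot(aatype, tf_dim).to(torch.float32)
print(target_feat.shape)         # torch.Size([16, 22]) with the values assumed above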