"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "7f9f5326be993cbdeffb51f946f13fd43f75d65b"
Commit 40d9e7d7 authored by Christina Floristean's avatar Christina Floristean
Browse files

Additional fix for multimer deepspeed test

parent 9a07b7f9
...@@ -293,6 +293,15 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -293,6 +293,15 @@ class TestDeepSpeedKernel(unittest.TestCase):
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0] batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
batch["no_recycling_iters"] = np.array([3., 3., 3., 3., ]) batch["no_recycling_iters"] = np.array([3., 3., 3., 3., ])
if consts.is_multimer:
n_res = batch['aatype'].shape[1]
n_extra_seq = batch['extra_msa'].shape[1]
batch["asym_id"] = np.ones((4, n_res))
batch["entity_id"] = np.ones((4, n_res))
batch["sym_id"] = np.ones((4, n_res))
batch["extra_deletion_matrix"] = np.random.randint(0, 2, size=(4, n_extra_seq, n_res))
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()} batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
batch["aatype"] = batch["aatype"].long() batch["aatype"] = batch["aatype"].long()
...@@ -301,6 +310,7 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -301,6 +310,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
batch["residx_atom37_to_atom14"] = batch[ batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14" "residx_atom37_to_atom14"
].long() ].long()
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], 21).to(torch.float32)
batch["template_all_atom_mask"] = batch["template_all_atom_masks"] batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update( batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch) data_transforms.atom37_to_torsion_angles("template_")(batch)
...@@ -309,7 +319,6 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -309,7 +319,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
# Move the recycling dimension to the end # Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0) move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch) batch = tensor_tree_map(move_dim, batch)
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(dtype=torch.bfloat16): with torch.cuda.amp.autocast(dtype=torch.bfloat16):
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
......
...@@ -27,6 +27,7 @@ from tests.config import consts ...@@ -27,6 +27,7 @@ from tests.config import consts
from tests.data_utils import ( from tests.data_utils import (
random_template_feats, random_template_feats,
random_extra_msa_feats, random_extra_msa_feats,
random_asym_ids
) )
if compare_utils.alphafold_is_installed(): if compare_utils.alphafold_is_installed():
...@@ -85,9 +86,9 @@ class TestModel(unittest.TestCase): ...@@ -85,9 +86,9 @@ class TestModel(unittest.TestCase):
batch["no_recycling_iters"] = torch.tensor(2.) batch["no_recycling_iters"] = torch.tensor(2.)
if consts.is_multimer: if consts.is_multimer:
batch["asym_id"] = torch.randint(0, 1, size=(n_res,)) batch["asym_id"] = torch.as_tensor(random_asym_ids(n_res))
batch["entity_id"] = torch.randint(0, 1, size=(n_res,)) batch["entity_id"] = batch["asym_id"].clone()
batch["sym_id"] = torch.randint(0, 1, size=(n_res,)) batch["sym_id"] = torch.ones(n_res)
batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res)) batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
add_recycling_dims = lambda t: ( add_recycling_dims = lambda t: (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment