Commit 9ebb1e1a authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

Updated the loss config so that the masked MSA dim is 22 instead of 23

parent 07421c47
...@@ -63,6 +63,8 @@ class TestPermutation: ...@@ -63,6 +63,8 @@ class TestPermutation:
n_extra_seq = consts.n_extra n_extra_seq = consts.n_extra
c = model_config(consts.model, train=True) c = model_config(consts.model, train=True)
c.loss.masked_msa.num_classes = 22 # somehow need overwrite this part in multimer loss config
c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test # deepspeed for this test
...@@ -88,6 +90,15 @@ class TestPermutation: ...@@ -88,6 +90,15 @@ class TestPermutation:
} }
batch['backbone_rigid_tensor'] = self.affine_vector_to_4x4(backbone_dict['backbone_affine_tensor']) batch['backbone_rigid_tensor'] = self.affine_vector_to_4x4(backbone_dict['backbone_affine_tensor'])
batch['backbone_rigid_mask'] = backbone_dict['backbone_affine_mask'] batch['backbone_rigid_mask'] = backbone_dict['backbone_affine_mask']
true_msa_dict ={
"true_msa": torch.tensor(np.random.randint(0, 21, (n_res, n_seq))),
"bert_mask": torch.tensor(np.random.randint(0, 2, (n_res, n_seq)).astype(
np.float32)
)
}
batch.update(true_msa_dict)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)) batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
t_feats = random_template_feats(n_templ, n_res) t_feats = random_template_feats(n_templ, n_res)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment