Commit 9ebb1e1a authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

updated the loss config so that the masked msa dim is 22 instead of 23

parent 07421c47
......@@ -63,6 +63,8 @@ class TestPermutation:
n_extra_seq = consts.n_extra
c = model_config(consts.model, train=True)
c.loss.masked_msa.num_classes = 22 # somehow need overwrite this part in multimer loss config
c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test
......@@ -89,6 +91,15 @@ class TestPermutation:
batch['backbone_rigid_tensor'] = self.affine_vector_to_4x4(backbone_dict['backbone_affine_tensor'])
batch['backbone_rigid_mask'] = backbone_dict['backbone_affine_mask']
true_msa_dict ={
"true_msa": torch.tensor(np.random.randint(0, 21, (n_res, n_seq))),
"bert_mask": torch.tensor(np.random.randint(0, 2, (n_res, n_seq)).astype(
np.float32)
)
}
batch.update(true_msa_dict)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
t_feats = random_template_feats(n_templ, n_res)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment