Commit 1b4fc115 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

fixed the input data generation so that true_msa and bert_mask shapes are...

fixed the input data generation so that true_msa and bert_mask shapes are correct otherwise crashes in loss calculations
parent 4f50aadc
...@@ -92,8 +92,8 @@ class TestPermutation: ...@@ -92,8 +92,8 @@ class TestPermutation:
batch['backbone_rigid_mask'] = backbone_dict['backbone_affine_mask'] batch['backbone_rigid_mask'] = backbone_dict['backbone_affine_mask']
true_msa_dict ={ true_msa_dict ={
"true_msa": torch.tensor(np.random.randint(0, 21, (n_res, n_seq))), "true_msa": torch.tensor(np.random.randint(0, 21, (n_seq,n_res))),
"bert_mask": torch.tensor(np.random.randint(0, 2, (n_res, n_seq)).astype( "bert_mask": torch.tensor(np.random.randint(0, 2, (n_seq,n_res)).astype(
np.float32) np.float32)
) )
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment