Commit 1e268fd5 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update test codes so that it generates correct structure of input data

parent 6eb1afe7
......@@ -48,7 +48,7 @@ class TestPermutation:
"""
self.test_data_dir = os.path.join(os.getcwd(),"tests/test_data")
self.label_ids = ['label_1','label_1','label_2','label_2','label_2']
self.asym_id = [1]*9+[2]*9+[3]*13+[4]*13 + [5]*13
self.asym_id = [0]*9+[1]*9+[2]*13+[3]*13 + [4]*13
def affine_vector_to_4x4(self,affine):
r = Rigid.from_tensor_7(affine)
......@@ -111,6 +111,7 @@ class TestPermutation:
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
batch.update(data_transforms.make_atom14_masks(batch))
batch["no_recycling_iters"] = torch.tensor(2.)
batch["seq_length"] = torch.from_numpy(np.array([n_res] * n_res))
if consts.is_multimer:
#
......@@ -120,7 +121,7 @@ class TestPermutation:
asym_id = self.asym_id
batch["asym_id"] = torch.tensor(asym_id,dtype=torch.float64)
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch['entity_id'] = torch.tensor([1]*18+[2]*39,dtype=torch.float64)
batch['entity_id'] = torch.tensor([0]*18+[1]*39,dtype=torch.float64)
batch["sym_id"] = torch.tensor(asym_id,dtype=torch.float64)
# batch["num_sym"] = torch.tensor([1]*18+[2]*13,dtype=torch.int64) # currently there are just 2 chains
batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment