Commit 1e268fd5 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update test codes so that it generates correct structure of input data

parent 6eb1afe7
...@@ -48,7 +48,7 @@ class TestPermutation: ...@@ -48,7 +48,7 @@ class TestPermutation:
""" """
self.test_data_dir = os.path.join(os.getcwd(),"tests/test_data") self.test_data_dir = os.path.join(os.getcwd(),"tests/test_data")
self.label_ids = ['label_1','label_1','label_2','label_2','label_2'] self.label_ids = ['label_1','label_1','label_2','label_2','label_2']
self.asym_id = [1]*9+[2]*9+[3]*13+[4]*13 + [5]*13 self.asym_id = [0]*9+[1]*9+[2]*13+[3]*13 + [4]*13
def affine_vector_to_4x4(self,affine): def affine_vector_to_4x4(self,affine):
r = Rigid.from_tensor_7(affine) r = Rigid.from_tensor_7(affine)
...@@ -111,6 +111,7 @@ class TestPermutation: ...@@ -111,6 +111,7 @@ class TestPermutation:
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float() batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
batch.update(data_transforms.make_atom14_masks(batch)) batch.update(data_transforms.make_atom14_masks(batch))
batch["no_recycling_iters"] = torch.tensor(2.) batch["no_recycling_iters"] = torch.tensor(2.)
batch["seq_length"] = torch.from_numpy(np.array([n_res] * n_res))
if consts.is_multimer: if consts.is_multimer:
# #
...@@ -120,7 +121,7 @@ class TestPermutation: ...@@ -120,7 +121,7 @@ class TestPermutation:
asym_id = self.asym_id asym_id = self.asym_id
batch["asym_id"] = torch.tensor(asym_id,dtype=torch.float64) batch["asym_id"] = torch.tensor(asym_id,dtype=torch.float64)
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,)) # batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch['entity_id'] = torch.tensor([1]*18+[2]*39,dtype=torch.float64) batch['entity_id'] = torch.tensor([0]*18+[1]*39,dtype=torch.float64)
batch["sym_id"] = torch.tensor(asym_id,dtype=torch.float64) batch["sym_id"] = torch.tensor(asym_id,dtype=torch.float64)
# batch["num_sym"] = torch.tensor([1]*18+[2]*13,dtype=torch.int64) # currently there are just 2 chains # batch["num_sym"] = torch.tensor([1]*18+[2]*13,dtype=torch.int64) # currently there are just 2 chains
batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res)) batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment