"lib/runtime/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "08fd28978c1480e5ec07f4dc82d9befa24908230"
Commit 093603ee authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update multimer data input pipeline

parent f3c1af45
...@@ -443,7 +443,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -443,7 +443,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
) )
self.feature_pipeline = feature_pipeline.FeaturePipeline(config) self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index): def _parse_mmcif(self, path, file_id,alignment_dir, alignment_index):
with open(path, 'r') as f: with open(path, 'r') as f:
mmcif_string = f.read() mmcif_string = f.read()
...@@ -462,7 +462,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -462,7 +462,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
mmcif=mmcif_object, mmcif=mmcif_object,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
alignment_index=alignment_index alignment_index=alignment_index
) )
return data return data
...@@ -471,10 +471,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -471,10 +471,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def idx_to_mmcif_id(self, idx): def idx_to_mmcif_id(self, idx):
return self._mmcifs[idx] return self._mmcifs[idx]
def __getitem__(self, idx): def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx) mmcif_id = self.idx_to_mmcif_id(idx)
alignment_index = None alignment_index = None
if(self.mode == 'train' or self.mode == 'eval'): if(self.mode == 'train' or self.mode == 'eval'):
path = os.path.join(self.data_dir, f"{mmcif_id}") path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None ext = None
...@@ -503,19 +504,19 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -503,19 +504,19 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
if (self._output_raw): if (self._output_raw):
return data return data
# process all_chain_features # process all_chain_features
data = self.feature_pipeline.process_features(data, data,ground_truth = self.feature_pipeline.process_features(data,
mode=self.mode, mode=self.mode,
is_multimer=True) is_multimer=True)
# if it's inference mode, only need all_chain_features # if it's inference mode, only need all_chain_features
data["batch_idx"] = torch.tensor( data["batch_idx"] = torch.tensor(
[idx for _ in range(data["aatype"].shape[-1])], [idx for _ in range(data["aatype"].shape[-1])],
dtype=torch.int64, dtype=torch.int64,
device=data["aatype"].device) device=data["aatype"].device)
return data return data, ground_truth
def __len__(self): def __len__(self):
return len(self._chain_ids) return len(self._chain_ids)
......
...@@ -81,7 +81,7 @@ def np_example_to_features( ...@@ -81,7 +81,7 @@ def np_example_to_features(
seq_length = np_example["seq_length"] seq_length = np_example["seq_length"]
num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length) num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length)
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
if "deletion_matrix_int" in np_example: if "deletion_matrix_int" in np_example:
np_example["deletion_matrix"] = np_example.pop( np_example["deletion_matrix"] = np_example.pop(
"deletion_matrix_int" "deletion_matrix_int"
...@@ -90,6 +90,7 @@ def np_example_to_features( ...@@ -90,6 +90,7 @@ def np_example_to_features(
tensor_dict = np_to_tensor_dict( tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names np_example=np_example, features=feature_names
) )
with torch.no_grad(): with torch.no_grad():
if(not is_multimer): if(not is_multimer):
features = input_pipeline.process_tensors_from_config( features = input_pipeline.process_tensors_from_config(
...@@ -98,7 +99,7 @@ def np_example_to_features( ...@@ -98,7 +99,7 @@ def np_example_to_features(
cfg[mode], cfg[mode],
) )
else: else:
features = input_pipeline_multimer.process_tensors_from_config( features,gt_features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict, tensor_dict,
cfg.common, cfg.common,
cfg[mode], cfg[mode],
...@@ -119,7 +120,7 @@ def np_example_to_features( ...@@ -119,7 +120,7 @@ def np_example_to_features(
dtype=torch.float32, dtype=torch.float32,
) )
return {k: v for k, v in features.items()} return {k: v for k, v in features.items()},gt_features
class FeaturePipeline: class FeaturePipeline:
......
...@@ -21,19 +21,11 @@ from openfold.data import ( ...@@ -21,19 +21,11 @@ from openfold.data import (
data_transforms_multimer, data_transforms_multimer,
) )
def grountruth_transforms_fns():
def nonensembled_transform_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that are not ensembled.""" transforms = []
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat,
data_transforms.make_atom14_masks,
]
if mode_cfg.supervised:
transforms.extend( transforms.extend(
[ [ data_transforms.make_atom14_masks,
data_transforms.make_atom14_positions, data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames, data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""), data_transforms.atom37_to_torsion_angles(""),
...@@ -42,6 +34,16 @@ def nonensembled_transform_fns(common_cfg, mode_cfg): ...@@ -42,6 +34,16 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.get_chi_angles, data_transforms.get_chi_angles,
] ]
) )
return transforms
def nonensembled_transform_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that are not ensembled."""
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat,
data_transforms.make_atom14_masks
]
return transforms return transforms
...@@ -118,6 +120,11 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -118,6 +120,11 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
def process_tensors_from_config(tensors, common_cfg, mode_cfg): def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data.""" """Based on the config, apply filters and transformations to the data."""
GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions']
input_tensors = {k:v for k,v in tensors.items() if k not in GROUNDTRUTH_FEATURES}
gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES}
gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
del tensors
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max) ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
def wrap_ensemble_fn(data, i): def wrap_ensemble_fn(data, i):
...@@ -132,27 +139,23 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg): ...@@ -132,27 +139,23 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
d["ensemble_index"] = i d["ensemble_index"] = i
return fn(d) return fn(d)
no_templates = True
if("template_aatype" in tensors):
no_templates = tensors["template_aatype"].shape[0] == 0
nonensembled = nonensembled_transform_fns( nonensembled = nonensembled_transform_fns(
common_cfg, common_cfg,
mode_cfg, mode_cfg,
) )
gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
tensors = compose(nonensembled)(tensors) input_tensors = compose(nonensembled)(input_tensors)
if("no_recycling_iters" in input_tensors):
if("no_recycling_iters" in tensors): num_recycling = int(input_tensors["no_recycling_iters"])
num_recycling = int(tensors["no_recycling_iters"])
else: else:
num_recycling = common_cfg.max_recycling_iters num_recycling = common_cfg.max_recycling_iters
tensors = map_fn( input_tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1) lambda x: wrap_ensemble_fn(input_tensors, x), torch.arange(num_recycling + 1)
) )
return tensors return input_tensors,gt_tensors
@data_transforms.curry1 @data_transforms.curry1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment