Commit 324b2ea6 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update input pipeline

parent fd748a0d
...@@ -36,7 +36,7 @@ def grountruth_transforms_fns(): ...@@ -36,7 +36,7 @@ def grountruth_transforms_fns():
) )
return transforms return transforms
def nonensembled_transform_fns(common_cfg, mode_cfg): def nonensembled_transform_fns():
"""Input pipeline data transformers that are not ensembled.""" """Input pipeline data transformers that are not ensembled."""
transforms = [ transforms = [
data_transforms.cast_to_64bit_ints, data_transforms.cast_to_64bit_ints,
...@@ -120,13 +120,10 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -120,13 +120,10 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
def process_tensors_from_config(tensors, common_cfg, mode_cfg): def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data.""" """Based on the config, apply filters and transformations to the data."""
GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions'] GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions','asym_id','sym_id','entity_id']
COMMON_FEATURES=['asym_id','sym_id','entity_id'] tensors['aatype'] = tensors['aatype'].to(torch.long)
input_tensors = {k:v for k,v in tensors.items() if k not in GROUNDTRUTH_FEATURES} gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES}
gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES or k in COMMON_FEATURES} gt_tensors['aatype'] = tensors['aatype']
gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
del tensors
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max) ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
def wrap_ensemble_fn(data, i): def wrap_ensemble_fn(data, i):
...@@ -147,17 +144,16 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg): ...@@ -147,17 +144,16 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
) )
gt_tensors = compose(grountruth_transforms_fns())(gt_tensors) gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
input_tensors = compose(nonensembled)(input_tensors) tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in input_tensors): if("no_recycling_iters" in tensors):
num_recycling = int(input_tensors["no_recycling_iters"]) num_recycling = int(tensors["no_recycling_iters"])
else: else:
num_recycling = common_cfg.max_recycling_iters num_recycling = common_cfg.max_recycling_iters
input_tensors = map_fn( tensors = map_fn(
lambda x: wrap_ensemble_fn(input_tensors, x), torch.arange(num_recycling + 1) lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
) )
return tensors,gt_tensors
return input_tensors,gt_tensors
@data_transforms.curry1 @data_transforms.curry1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment