Commit 61c7640c authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

move ground truth preparation out of process_tensors_from_config

parent a9cf892f
...@@ -112,10 +112,21 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -112,10 +112,21 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
return transforms return transforms
def prepare_ground_truth_features(tensors):
"""Prepare ground truth features that are only needed for loss calculation during training"""
GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions','asym_id','sym_id','entity_id']
gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES}
gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
return gt_tensors
def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False): def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False):
"""Based on the config, apply filters and transformations to the data.""" """Based on the config, apply filters and transformations to the data."""
if is_training:
gt_tensors= prepare_ground_truth_features(tensors)
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max) ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
tensors['aatype'] = tensors['aatype'].to(torch.long) tensors['aatype'] = tensors['aatype'].to(torch.long)
nonensembled = nonensembled_transform_fns() nonensembled = nonensembled_transform_fns()
...@@ -142,11 +153,6 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False) ...@@ -142,11 +153,6 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False)
) )
if is_training: if is_training:
GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions','asym_id','sym_id','entity_id']
gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES}
gt_tensors['aatype'] = tensors['aatype']
gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
return tensors,gt_tensors return tensors,gt_tensors
else: else:
return tensors return tensors
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment