"examples/llm/vscode:/vscode.git/clone" did not exist on "cce0c0287f302bc2b3d562e6f207007776a8310f"
Commit 9aebc203 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update the returned features when it's not training mode in multimer input pipeline

parent 68389359
......@@ -113,14 +113,17 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
return transforms
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False):
"""Based on the config, apply filters and transformations to the data."""
GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions','asym_id','sym_id','entity_id']
tensors['aatype'] = tensors['aatype'].to(torch.long)
gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES}
gt_tensors['aatype'] = tensors['aatype']
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
tensors['aatype'] = tensors['aatype'].to(torch.long)
nonensembled = nonensembled_transform_fns()
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
......@@ -134,20 +137,19 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
d["ensemble_index"] = i
return fn(d)
nonensembled = nonensembled_transform_fns()
gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
)
return tensors,gt_tensors
if is_training:
GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions','asym_id','sym_id','entity_id']
gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES}
gt_tensors['aatype'] = tensors['aatype']
gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
return tensors,gt_tensors
else:
return tensors
@data_transforms.curry1
def compose(x, fs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment