Commit e4d7f6d2 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

need to change ground truth load mode to train to avoid loss error

parent 08a5be86
......@@ -503,8 +503,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
data = self._parse_mmcif(
path, mmcif_id, chain, alignment_dir, alignment_index,
)
# since it's ground truth features, change the mode to eval in order to avoid padding
ground_truth_feats = self.feature_pipeline.process_features(data, "eval",
ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False)
#remove recycling dimension
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
......@@ -513,7 +512,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, "eval",
ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
......@@ -527,7 +526,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
alignment_index=alignment_index,
_structure_index=structure_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, "eval",
ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment