Commit e4d7f6d2 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

need to change ground truth load mode to train to avoid loss error

parent 08a5be86
...@@ -503,8 +503,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -503,8 +503,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
data = self._parse_mmcif( data = self._parse_mmcif(
path, mmcif_id, chain, alignment_dir, alignment_index, path, mmcif_id, chain, alignment_dir, alignment_index,
) )
# since it's ground truth features, change the mode to eval in order to avoid padding ground_truth_feats = self.feature_pipeline.process_features(data, "train",
ground_truth_feats = self.feature_pipeline.process_features(data, "eval",
is_multimer=False) is_multimer=False)
#remove recycling dimension #remove recycling dimension
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats) ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
...@@ -513,7 +512,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -513,7 +512,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
data = self.data_pipeline.process_core( data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index, path, alignment_dir, alignment_index,
) )
ground_truth_feats = self.feature_pipeline.process_features(data, "eval", ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False) is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats) ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats) ground_truth.append(ground_truth_feats)
...@@ -527,7 +526,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -527,7 +526,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
alignment_index=alignment_index, alignment_index=alignment_index,
_structure_index=structure_index, _structure_index=structure_index,
) )
ground_truth_feats = self.feature_pipeline.process_features(data, "eval", ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False) is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats) ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats) ground_truth.append(ground_truth_feats)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment