"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "5ddc7f7df5ab77c4efae9fd6ca299c3040c91533"
Commit 08bfb1ff authored by Dingquan Yu's avatar Dingquan Yu
Browse files

reverse back to multimer branch version

parent 28b9e2b6
...@@ -432,7 +432,7 @@ config = mlc.ConfigDict( ...@@ -432,7 +432,7 @@ config = mlc.ConfigDict(
"use_small_bfd": False, "use_small_bfd": False,
"data_loaders": { "data_loaders": {
"batch_size": 1, "batch_size": 1,
"num_workers": 1, "num_workers": 16,
"pin_memory": True, "pin_memory": True,
}, },
}, },
...@@ -764,7 +764,7 @@ multimer_config_update = mlc.ConfigDict({ ...@@ -764,7 +764,7 @@ multimer_config_update = mlc.ConfigDict({
], ],
"true_msa": [NUM_MSA_SEQ, NUM_RES] "true_msa": [NUM_MSA_SEQ, NUM_RES]
}, },
"max_recycling_iters": 1, # For training, value is 3 "max_recycling_iters": 20, # For training, value is 3
"unsupervised_features": [ "unsupervised_features": [
"aatype", "aatype",
"residue_index", "residue_index",
...@@ -799,7 +799,7 @@ multimer_config_update = mlc.ConfigDict({ ...@@ -799,7 +799,7 @@ multimer_config_update = mlc.ConfigDict({
"train": { "train": {
"max_msa_clusters": 508, "max_msa_clusters": 508,
"max_extra_msa": 2048, "max_extra_msa": 2048,
"crop_size": 32, "crop_size": 640,
"spatial_crop_prob": 0.5, "spatial_crop_prob": 0.5,
"interface_threshold": 10., "interface_threshold": 10.,
"clamp_prob": 1., "clamp_prob": 1.,
......
...@@ -195,6 +195,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -195,6 +195,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_index=alignment_index, alignment_index=alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled seqemb_mode=self.config.seqemb_mode.enabled
) )
return data return data
def chain_id_to_idx(self, chain_id): def chain_id_to_idx(self, chain_id):
...@@ -422,21 +423,24 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -422,21 +423,24 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index): def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
with open(path, 'r') as f: with open(path, 'r') as f:
mmcif_string = f.read() mmcif_string = f.read()
import time
mmcif_object = mmcif_parsing.parse( mmcif_object = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string file_id=file_id, mmcif_string=mmcif_string
) )
# Crash if an error is encountered. Any parsing errors should have # Crash if an error is encountered. Any parsing errors should have
# been dealt with at the alignment stage. # been dealt with at the alignment stage.
if mmcif_object.mmcif_object is None: if mmcif_object.mmcif_object is None:
raise list(mmcif_object.errors.values())[0] raise list(mmcif_object.errors.values())[0]
mmcif_object = mmcif_object.mmcif_object mmcif_object = mmcif_object.mmcif_object
data = self.data_pipeline.process_mmcif( data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object, mmcif=mmcif_object,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
alignment_index=alignment_index alignment_index=alignment_index
) )
return data return data
def mmcif_id_to_idx(self, mmcif_id): def mmcif_id_to_idx(self, mmcif_id):
...@@ -450,8 +454,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -450,8 +454,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
alignment_index = None alignment_index = None
if self.mode == 'train' or self.mode == 'eval': if self.mode == 'train' or self.mode == 'eval':
import time
start = time.time()
path = os.path.join(self.data_dir, f"{mmcif_id}") path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None ext = None
for e in self.supported_exts: for e in self.supported_exts:
...@@ -476,6 +478,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -476,6 +478,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
fasta_path=path, fasta_path=path,
alignment_dir=self.alignment_dir alignment_dir=self.alignment_dir
) )
if self._output_raw: if self._output_raw:
return data return data
...@@ -483,7 +486,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -483,7 +486,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
data = self.feature_pipeline.process_features(data, data = self.feature_pipeline.process_features(data,
mode=self.mode, mode=self.mode,
is_multimer=True) is_multimer=True)
# if it's inference mode, only need all_chain_features # if it's inference mode, only need all_chain_features
data["batch_idx"] = torch.tensor( data["batch_idx"] = torch.tensor(
[idx for _ in range(data["aatype"].shape[-1])], [idx for _ in range(data["aatype"].shape[-1])],
...@@ -738,7 +741,9 @@ class OpenFoldMultimerDataset(OpenFoldDataset): ...@@ -738,7 +741,9 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
generator=self.generator, generator=self.generator,
) )
samples = samples.squeeze() samples = samples.squeeze()
cache = [i for i, s in zip(idx, samples) if s] cache = [i for i, s in zip(idx, samples) if s]
for datapoint_idx in cache: for datapoint_idx in cache:
yield datapoint_idx yield datapoint_idx
......
...@@ -20,7 +20,7 @@ import itertools ...@@ -20,7 +20,7 @@ import itertools
import re import re
import string import string
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
import asyncio
DeletionMatrix = Sequence[Sequence[int]] DeletionMatrix = Sequence[Sequence[int]]
...@@ -120,9 +120,9 @@ def parse_stockholm(stockholm_string: str) -> Msa: ...@@ -120,9 +120,9 @@ def parse_stockholm(stockholm_string: str) -> Msa:
line = line.strip() line = line.strip()
if not line or line.startswith(("#", "//")): if not line or line.startswith(("#", "//")):
continue continue
name, sequence = line.split(maxsplit=1) name, sequence = line.split()
if name not in name_to_sequence: if name not in name_to_sequence:
name_to_sequence.setdefault(name,"") name_to_sequence[name] = ""
name_to_sequence[name] += sequence name_to_sequence[name] += sequence
msa = [] msa = []
......
...@@ -42,9 +42,6 @@ from scripts.zero_to_fp32 import ( ...@@ -42,9 +42,6 @@ from scripts.zero_to_fp32 import (
from openfold.utils.logger import PerformanceLoggingCallback from openfold.utils.logger import PerformanceLoggingCallback
def calculate_elapse(start, end):
elapse = end - start
print(f"this function runs {round(elapse,3)} seconds i.e. {round(elapse/60, 3)} minutes")
class OpenFoldWrapper(pl.LightningModule): class OpenFoldWrapper(pl.LightningModule):
def __init__(self, config): def __init__(self, config):
...@@ -319,7 +316,7 @@ def main(args): ...@@ -319,7 +316,7 @@ def main(args):
batch_seed=args.seed, batch_seed=args.seed,
**vars(args) **vars(args)
) )
data_module.prepare_data() data_module.prepare_data()
data_module.setup() data_module.setup()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment