Commit 08bfb1ff authored by Dingquan Yu
Browse files

Revert to the multimer branch version

parent 28b9e2b6
......@@ -432,7 +432,7 @@ config = mlc.ConfigDict(
"use_small_bfd": False,
"data_loaders": {
"batch_size": 1,
"num_workers": 1,
"num_workers": 16,
"pin_memory": True,
},
},
......@@ -764,7 +764,7 @@ multimer_config_update = mlc.ConfigDict({
],
"true_msa": [NUM_MSA_SEQ, NUM_RES]
},
"max_recycling_iters": 1, # For training, value is 3
"max_recycling_iters": 20, # For training, value is 3
"unsupervised_features": [
"aatype",
"residue_index",
......@@ -799,7 +799,7 @@ multimer_config_update = mlc.ConfigDict({
"train": {
"max_msa_clusters": 508,
"max_extra_msa": 2048,
"crop_size": 32,
"crop_size": 640,
"spatial_crop_prob": 0.5,
"interface_threshold": 10.,
"clamp_prob": 1.,
......
......@@ -195,6 +195,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_index=alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled
)
return data
def chain_id_to_idx(self, chain_id):
......@@ -422,21 +423,24 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
    """Read an mmCIF file, parse it, and run it through the data pipeline.

    Args:
        path: Filesystem path of the mmCIF file to read.
        file_id: Identifier forwarded to ``mmcif_parsing.parse``.
        alignment_dir: Forwarded unchanged to
            ``self.data_pipeline.process_mmcif``.
        alignment_index: Forwarded unchanged to
            ``self.data_pipeline.process_mmcif``.

    Returns:
        Whatever ``self.data_pipeline.process_mmcif`` returns for the
        parsed structure.

    Raises:
        The first error recorded in the parsing result when the parser
        did not produce an mmCIF object.
    """
    with open(path, 'r') as f:
        mmcif_string = f.read()

    parsing_result = mmcif_parsing.parse(
        file_id=file_id, mmcif_string=mmcif_string
    )

    # Crash if an error is encountered. Any parsing errors should have
    # been dealt with at the alignment stage.
    if parsing_result.mmcif_object is None:
        raise list(parsing_result.errors.values())[0]

    return self.data_pipeline.process_mmcif(
        mmcif=parsing_result.mmcif_object,
        alignment_dir=alignment_dir,
        alignment_index=alignment_index
    )
def mmcif_id_to_idx(self, mmcif_id):
......@@ -450,8 +454,6 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
alignment_index = None
if self.mode == 'train' or self.mode == 'eval':
import time
start = time.time()
path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None
for e in self.supported_exts:
......@@ -476,6 +478,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
fasta_path=path,
alignment_dir=self.alignment_dir
)
if self._output_raw:
return data
......@@ -738,7 +741,9 @@ class OpenFoldMultimerDataset(OpenFoldDataset):
generator=self.generator,
)
samples = samples.squeeze()
cache = [i for i, s in zip(idx, samples) if s]
for datapoint_idx in cache:
yield datapoint_idx
......
......@@ -20,7 +20,7 @@ import itertools
import re
import string
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
import asyncio
DeletionMatrix = Sequence[Sequence[int]]
......@@ -120,9 +120,9 @@ def parse_stockholm(stockholm_string: str) -> Msa:
line = line.strip()
if not line or line.startswith(("#", "//")):
continue
name, sequence = line.split(maxsplit=1)
name, sequence = line.split()
if name not in name_to_sequence:
name_to_sequence.setdefault(name,"")
name_to_sequence[name] = ""
name_to_sequence[name] += sequence
msa = []
......
......@@ -42,9 +42,6 @@ from scripts.zero_to_fp32 import (
from openfold.utils.logger import PerformanceLoggingCallback
def calculate_elapse(start, end):
    """Print the time elapsed between ``start`` and ``end``.

    The difference ``end - start`` is reported twice on one line: once
    rounded to three decimals as seconds, and once divided by 60 and
    rounded to three decimals as minutes.

    Args:
        start: Starting timestamp (presumably a ``time.time()`` reading,
            in seconds -- confirm against callers).
        end: Ending timestamp in the same unit as ``start``.
    """
    duration = end - start
    seconds = round(duration, 3)
    minutes = round(duration / 60, 3)
    print(f"this function runs {seconds} seconds i.e. {minutes} minutes")
class OpenFoldWrapper(pl.LightningModule):
def __init__(self, config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment