"components/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "2ee29443b6ec02d875d1e5bdf1c47667123f4be4"
Commit 753fc31f authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

turn all_seq_features from numpy to tensor and move all_seq_featurs to cuda

parent 9cab17c4
......@@ -21,7 +21,9 @@ from openfold.data import (
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
import contextlib
import tempfile
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
@contextlib.contextmanager
def temp_fasta_file(sequence_str):
"""function that create temparory fasta file used in multimer datapipeline"""
......@@ -212,7 +214,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __getitem__(self, idx):
name = self.idx_to_chain_id(idx)
print(f"name is {name}")
alignment_dir = os.path.join(self.alignment_dir, name)
alignment_index = None
......@@ -476,13 +477,18 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
fasta_str+=f">{mmcif_id}_{c}\n{s}\n"
with temp_fasta_file(fasta_str) as fasta_file:
all_chain_features = self.multimer_data_pipeline.process_fasta(fasta_file,self.alignment_dir)
for k,v in all_chain_features.items():
all_chain_features[k] = torch.tensor(v)
move_to_cuda = lambda t: t.to('cuda')
## move all_chain_features to gpu
all_chain_features = tensor_tree_map(move_to_cuda,all_chain_features)
alignment_index = None
ground_truth=[]
if(self.mode == 'train' or self.mode == 'eval'):
for chain in chains:
path = os.path.join(self.alignment_dir, f"{mmcif_id}_{chain.upper()}")
print(f"path is {path}")
path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
......@@ -493,21 +499,22 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
raise ValueError("Invalid file type")
path += ext
alignment_dir = os.path.join(self.alignment_dir,f"{mmcif_id}_{chain.upper()}")
if(ext == ".cif"):
data = self._parse_mmcif(
path, mmcif_id, chain, self.alignment_dir, alignment_index,
path, mmcif_id, chain, alignment_dir, alignment_index,
)
ground_truth.append(data)
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, self.alignment_dir, alignment_index,
path, alignment_dir, alignment_index,
)
ground_truth.append(data)
elif(ext == ".pdb"):
structure_index = None
data = self.data_pipeline.process_pdb(
pdb_path=path,
alignment_dir=self.alignment_dir,
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain,
alignment_index=alignment_index,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment