"docs/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "670661f6fa85f6ffc77433d21b363285d6cec32f"
Commit 753fc31f authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

turn all_seq_features from numpy to tensor and move all_seq_featurs to cuda

parent 9cab17c4
...@@ -21,7 +21,9 @@ from openfold.data import ( ...@@ -21,7 +21,9 @@ from openfold.data import (
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
import contextlib import contextlib
import tempfile import tempfile
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
@contextlib.contextmanager @contextlib.contextmanager
def temp_fasta_file(sequence_str): def temp_fasta_file(sequence_str):
"""function that create temparory fasta file used in multimer datapipeline""" """function that create temparory fasta file used in multimer datapipeline"""
...@@ -212,7 +214,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -212,7 +214,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
name = self.idx_to_chain_id(idx) name = self.idx_to_chain_id(idx)
print(f"name is {name}")
alignment_dir = os.path.join(self.alignment_dir, name) alignment_dir = os.path.join(self.alignment_dir, name)
alignment_index = None alignment_index = None
...@@ -476,13 +477,18 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -476,13 +477,18 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
fasta_str+=f">{mmcif_id}_{c}\n{s}\n" fasta_str+=f">{mmcif_id}_{c}\n{s}\n"
with temp_fasta_file(fasta_str) as fasta_file: with temp_fasta_file(fasta_str) as fasta_file:
all_chain_features = self.multimer_data_pipeline.process_fasta(fasta_file,self.alignment_dir) all_chain_features = self.multimer_data_pipeline.process_fasta(fasta_file,self.alignment_dir)
for k,v in all_chain_features.items():
all_chain_features[k] = torch.tensor(v)
move_to_cuda = lambda t: t.to('cuda')
## move all_chain_features to gpu
all_chain_features = tensor_tree_map(move_to_cuda,all_chain_features)
alignment_index = None alignment_index = None
ground_truth=[] ground_truth=[]
if(self.mode == 'train' or self.mode == 'eval'): if(self.mode == 'train' or self.mode == 'eval'):
for chain in chains: for chain in chains:
path = os.path.join(self.alignment_dir, f"{mmcif_id}_{chain.upper()}") path = os.path.join(self.data_dir, f"{mmcif_id}")
print(f"path is {path}")
ext = None ext = None
for e in self.supported_exts: for e in self.supported_exts:
if(os.path.exists(path + e)): if(os.path.exists(path + e)):
...@@ -493,21 +499,22 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -493,21 +499,22 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
raise ValueError("Invalid file type") raise ValueError("Invalid file type")
path += ext path += ext
alignment_dir = os.path.join(self.alignment_dir,f"{mmcif_id}_{chain.upper()}")
if(ext == ".cif"): if(ext == ".cif"):
data = self._parse_mmcif( data = self._parse_mmcif(
path, mmcif_id, chain, self.alignment_dir, alignment_index, path, mmcif_id, chain, alignment_dir, alignment_index,
) )
ground_truth.append(data) ground_truth.append(data)
elif(ext == ".core"): elif(ext == ".core"):
data = self.data_pipeline.process_core( data = self.data_pipeline.process_core(
path, self.alignment_dir, alignment_index, path, alignment_dir, alignment_index,
) )
ground_truth.append(data) ground_truth.append(data)
elif(ext == ".pdb"): elif(ext == ".pdb"):
structure_index = None structure_index = None
data = self.data_pipeline.process_pdb( data = self.data_pipeline.process_pdb(
pdb_path=path, pdb_path=path,
alignment_dir=self.alignment_dir, alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation, is_distillation=self.treat_pdb_as_distillation,
chain_id=chain, chain_id=chain,
alignment_index=alignment_index, alignment_index=alignment_index,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment