Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
753fc31f
"docs/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "670661f6fa85f6ffc77433d21b363285d6cec32f"
Commit
753fc31f
authored
Jul 11, 2023
by
Geoffrey Yu
Browse files
turn all_seq_features from numpy to tensor and move all_seq_featurs to cuda
parent
9cab17c4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
7 deletions
+14
-7
openfold/data/data_modules.py
openfold/data/data_modules.py
+14
-7
No files found.
openfold/data/data_modules.py
View file @
753fc31f
...
@@ -21,7 +21,9 @@ from openfold.data import (
...
@@ -21,7 +21,9 @@ from openfold.data import (
from
openfold.utils.tensor_utils
import
tensor_tree_map
,
dict_multimap
from
openfold.utils.tensor_utils
import
tensor_tree_map
,
dict_multimap
import
contextlib
import
contextlib
import
tempfile
import
tempfile
from
openfold.utils.tensor_utils
import
(
tensor_tree_map
,
)
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
temp_fasta_file
(
sequence_str
):
def
temp_fasta_file
(
sequence_str
):
"""function that create temparory fasta file used in multimer datapipeline"""
"""function that create temparory fasta file used in multimer datapipeline"""
...
@@ -212,7 +214,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -212,7 +214,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
name
=
self
.
idx_to_chain_id
(
idx
)
name
=
self
.
idx_to_chain_id
(
idx
)
print
(
f
"name is
{
name
}
"
)
alignment_dir
=
os
.
path
.
join
(
self
.
alignment_dir
,
name
)
alignment_dir
=
os
.
path
.
join
(
self
.
alignment_dir
,
name
)
alignment_index
=
None
alignment_index
=
None
...
@@ -476,13 +477,18 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -476,13 +477,18 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
fasta_str
+=
f
">
{
mmcif_id
}
_
{
c
}
\n
{
s
}
\n
"
fasta_str
+=
f
">
{
mmcif_id
}
_
{
c
}
\n
{
s
}
\n
"
with
temp_fasta_file
(
fasta_str
)
as
fasta_file
:
with
temp_fasta_file
(
fasta_str
)
as
fasta_file
:
all_chain_features
=
self
.
multimer_data_pipeline
.
process_fasta
(
fasta_file
,
self
.
alignment_dir
)
all_chain_features
=
self
.
multimer_data_pipeline
.
process_fasta
(
fasta_file
,
self
.
alignment_dir
)
for
k
,
v
in
all_chain_features
.
items
():
all_chain_features
[
k
]
=
torch
.
tensor
(
v
)
move_to_cuda
=
lambda
t
:
t
.
to
(
'cuda'
)
## move all_chain_features to gpu
all_chain_features
=
tensor_tree_map
(
move_to_cuda
,
all_chain_features
)
alignment_index
=
None
alignment_index
=
None
ground_truth
=
[]
ground_truth
=
[]
if
(
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
):
if
(
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
):
for
chain
in
chains
:
for
chain
in
chains
:
path
=
os
.
path
.
join
(
self
.
alignment_dir
,
f
"
{
mmcif_id
}
_
{
chain
.
upper
()
}
"
)
path
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
mmcif_id
}
"
)
print
(
f
"path is
{
path
}
"
)
ext
=
None
ext
=
None
for
e
in
self
.
supported_exts
:
for
e
in
self
.
supported_exts
:
if
(
os
.
path
.
exists
(
path
+
e
)):
if
(
os
.
path
.
exists
(
path
+
e
)):
...
@@ -493,21 +499,22 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -493,21 +499,22 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
raise
ValueError
(
"Invalid file type"
)
raise
ValueError
(
"Invalid file type"
)
path
+=
ext
path
+=
ext
alignment_dir
=
os
.
path
.
join
(
self
.
alignment_dir
,
f
"
{
mmcif_id
}
_
{
chain
.
upper
()
}
"
)
if
(
ext
==
".cif"
):
if
(
ext
==
".cif"
):
data
=
self
.
_parse_mmcif
(
data
=
self
.
_parse_mmcif
(
path
,
mmcif_id
,
chain
,
self
.
alignment_dir
,
alignment_index
,
path
,
mmcif_id
,
chain
,
alignment_dir
,
alignment_index
,
)
)
ground_truth
.
append
(
data
)
ground_truth
.
append
(
data
)
elif
(
ext
==
".core"
):
elif
(
ext
==
".core"
):
data
=
self
.
data_pipeline
.
process_core
(
data
=
self
.
data_pipeline
.
process_core
(
path
,
self
.
alignment_dir
,
alignment_index
,
path
,
alignment_dir
,
alignment_index
,
)
)
ground_truth
.
append
(
data
)
ground_truth
.
append
(
data
)
elif
(
ext
==
".pdb"
):
elif
(
ext
==
".pdb"
):
structure_index
=
None
structure_index
=
None
data
=
self
.
data_pipeline
.
process_pdb
(
data
=
self
.
data_pipeline
.
process_pdb
(
pdb_path
=
path
,
pdb_path
=
path
,
alignment_dir
=
self
.
alignment_dir
,
alignment_dir
=
alignment_dir
,
is_distillation
=
self
.
treat_pdb_as_distillation
,
is_distillation
=
self
.
treat_pdb_as_distillation
,
chain_id
=
chain
,
chain_id
=
chain
,
alignment_index
=
alignment_index
,
alignment_index
=
alignment_index
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment