Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
44b0bf76
"vscode:/vscode.git/clone" did not exist on "7b42d1662941866320ec18e5e19ed877c2d11a28"
Commit
44b0bf76
authored
Aug 10, 2023
by
Christina Floristean
Browse files
Merge branch 'main' into multimer
parents
959b3f25
1dc2748c
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
19 additions
and
15 deletions
+19
-15
openfold/data/data_modules.py
openfold/data/data_modules.py
+3
-6
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+8
-3
openfold/data/input_pipeline.py
openfold/data/input_pipeline.py
+2
-2
tests/test_evoformer.py
tests/test_evoformer.py
+5
-3
tests/test_model.py
tests/test_model.py
+1
-0
train_openfold.py
train_openfold.py
+0
-1
No files found.
openfold/data/data_modules.py
View file @
44b0bf76
...
@@ -797,10 +797,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
...
@@ -797,10 +797,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
config
=
config
self
.
config
=
config
self
.
stage
=
stage
self
.
stage
=
stage
if
(
generator
is
None
):
generator
=
torch
.
Generator
()
self
.
generator
=
generator
self
.
generator
=
generator
self
.
_prep_batch_properties_probs
()
self
.
_prep_batch_properties_probs
()
...
@@ -1077,8 +1073,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -1077,8 +1073,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
)
def
_gen_dataloader
(
self
,
stage
):
def
_gen_dataloader
(
self
,
stage
):
generator
=
torch
.
Generator
()
generator
=
None
if
(
self
.
batch_seed
is
not
None
):
if
(
self
.
batch_seed
is
not
None
):
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
self
.
batch_seed
)
generator
=
generator
.
manual_seed
(
self
.
batch_seed
)
dataset
=
None
dataset
=
None
...
...
openfold/data/data_transforms.py
View file @
44b0bf76
...
@@ -186,9 +186,12 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
...
@@ -186,9 +186,12 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
def
sample_msa
(
protein
,
max_seq
,
keep_extra
,
seed
=
None
):
def
sample_msa
(
protein
,
max_seq
,
keep_extra
,
seed
=
None
):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
num_seq
=
protein
[
"msa"
].
shape
[
0
]
num_seq
=
protein
[
"msa"
].
shape
[
0
]
g
=
torch
.
Generator
(
device
=
protein
[
"msa"
].
device
)
g
=
None
if
seed
is
not
None
:
if
seed
is
not
None
:
g
=
torch
.
Generator
(
device
=
protein
[
"msa"
].
device
)
g
.
manual_seed
(
seed
)
g
.
manual_seed
(
seed
)
shuffled
=
torch
.
randperm
(
num_seq
-
1
,
generator
=
g
)
+
1
shuffled
=
torch
.
randperm
(
num_seq
-
1
,
generator
=
g
)
+
1
index_order
=
torch
.
cat
(
index_order
=
torch
.
cat
(
(
torch
.
tensor
([
0
],
device
=
shuffled
.
device
),
shuffled
),
(
torch
.
tensor
([
0
],
device
=
shuffled
.
device
),
shuffled
),
...
@@ -1181,8 +1184,10 @@ def random_crop_to_size(
...
@@ -1181,8 +1184,10 @@ def random_crop_to_size(
):
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
# We want each ensemble to be cropped the same way
g
=
torch
.
Generator
(
device
=
protein
[
"seq_length"
].
device
)
g
=
None
if
seed
is
not
None
:
if
seed
is
not
None
:
g
=
torch
.
Generator
(
device
=
protein
[
"seq_length"
].
device
)
g
.
manual_seed
(
seed
)
g
.
manual_seed
(
seed
)
seq_length
=
protein
[
"seq_length"
]
seq_length
=
protein
[
"seq_length"
]
...
...
openfold/data/input_pipeline.py
View file @
44b0bf76
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
functools
import
partial
import
random
import
torch
import
torch
...
@@ -154,7 +154,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
...
@@ -154,7 +154,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
):
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
):
"""Based on the config, apply filters and transformations to the data."""
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed
=
torch
.
Generator
().
seed
(
)
ensemble_seed
=
random
.
randint
(
0
,
torch
.
iinfo
(
torch
.
int32
).
max
)
def
wrap_ensemble_fn
(
data
,
i
):
def
wrap_ensemble_fn
(
data
,
i
):
"""Function to be mapped over the ensemble dimension."""
"""Function to be mapped over the ensemble dimension."""
...
...
tests/test_evoformer.py
View file @
44b0bf76
...
@@ -202,10 +202,10 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -202,10 +202,10 @@ class TestExtraMSAStack(unittest.TestCase):
ckpt
=
False
,
ckpt
=
False
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
).
eval
()
).
eval
()
.
cuda
()
m
=
torch
.
rand
((
batch_size
,
s_t
,
n_res
,
c_m
))
m
=
torch
.
rand
((
batch_size
,
s_t
,
n_res
,
c_m
)
,
device
=
"cuda"
)
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
))
z
=
torch
.
rand
((
batch_size
,
n_res
,
n_res
,
c_z
)
,
device
=
"cuda"
)
msa_mask
=
torch
.
randint
(
msa_mask
=
torch
.
randint
(
0
,
0
,
2
,
2
,
...
@@ -214,6 +214,7 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -214,6 +214,7 @@ class TestExtraMSAStack(unittest.TestCase):
s_t
,
s_t
,
n_res
,
n_res
,
),
),
device
=
"cuda"
,
)
)
pair_mask
=
torch
.
randint
(
pair_mask
=
torch
.
randint
(
0
,
0
,
...
@@ -223,6 +224,7 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -223,6 +224,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res
,
n_res
,
n_res
,
n_res
,
),
),
device
=
"cuda"
,
)
)
shape_z_before
=
z
.
shape
shape_z_before
=
z
.
shape
...
...
tests/test_model.py
View file @
44b0bf76
...
@@ -61,6 +61,7 @@ class TestModel(unittest.TestCase):
...
@@ -61,6 +61,7 @@ class TestModel(unittest.TestCase):
# deepspeed for this test
# deepspeed for this test
model
=
AlphaFold
(
c
)
model
=
AlphaFold
(
c
)
model
.
eval
()
batch
=
{}
batch
=
{}
tf
=
torch
.
randint
(
c
.
model
.
input_embedder
.
tf_dim
-
1
,
size
=
(
n_res
,))
tf
=
torch
.
randint
(
c
.
model
.
input_embedder
.
tf_dim
-
1
,
size
=
(
n_res
,))
...
...
train_openfold.py
View file @
44b0bf76
...
@@ -88,7 +88,6 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -88,7 +88,6 @@ class OpenFoldWrapper(pl.LightningModule):
)
)
for
k
,
v
in
other_metrics
.
items
():
for
k
,
v
in
other_metrics
.
items
():
assert
(
len
(
v
.
shape
)
==
1
)
self
.
log
(
self
.
log
(
f
"
{
phase
}
/
{
k
}
"
,
f
"
{
phase
}
/
{
k
}
"
,
torch
.
mean
(
v
),
torch
.
mean
(
v
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment