Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
30764cf9
"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "06eac9eaa70be4ca9aacd82a323134cfde3604b2"
Commit
30764cf9
authored
Aug 03, 2023
by
Christina Floristean
Browse files
Minor fixes/reformatting for recent multimer training PR
parent
31051cf2
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
60 additions
and
92 deletions
+60
-92
openfold/config.py
openfold/config.py
+9
-24
openfold/data/data_modules.py
openfold/data/data_modules.py
+15
-7
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+16
-49
openfold/utils/loss.py
openfold/utils/loss.py
+14
-7
tests/compare_utils.py
tests/compare_utils.py
+2
-1
tests/test_template.py
tests/test_template.py
+4
-4
No files found.
openfold/config.py
View file @
30764cf9
...
@@ -160,10 +160,10 @@ def model_config(
...
@@ -160,10 +160,10 @@ def model_config(
c
.
loss
.
masked_msa
.
num_classes
=
22
c
.
loss
.
masked_msa
.
num_classes
=
22
c
.
data
.
common
.
max_recycling_iters
=
20
c
.
data
.
common
.
max_recycling_iters
=
20
for
k
,
v
in
multimer_model_config_update
[
'model'
].
items
():
for
k
,
v
in
multimer_model_config_update
[
'model'
].
items
():
c
.
model
[
k
]
=
v
c
.
model
[
k
]
=
v
for
k
,
v
in
multimer_model_config_update
[
'loss'
].
items
():
for
k
,
v
in
multimer_model_config_update
[
'loss'
].
items
():
c
.
loss
[
k
]
=
v
c
.
loss
[
k
]
=
v
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
...
@@ -683,24 +683,11 @@ config = mlc.ConfigDict(
...
@@ -683,24 +683,11 @@ config = mlc.ConfigDict(
)
)
multimer_model_config_update
=
{
multimer_model_config_update
=
{
'model'
:{
"input_embedder"
:
{
'model'
:
{
"tf_dim"
:
21
,
"input_embedder"
:
{
"msa_dim"
:
49
,
"tf_dim"
:
21
,
#"num_msa": 508,
"msa_dim"
:
49
,
"c_z"
:
c_z
,
#"num_msa": 508,
"c_m"
:
c_m
,
"relpos_k"
:
32
,
"max_relative_chain"
:
2
,
"max_relative_idx"
:
32
,
"use_chain_relative"
:
True
,
},
"template"
:
{
"distogram"
:
{
"min_bin"
:
3.25
,
"max_bin"
:
50.75
,
"no_bins"
:
39
,
},
"template_pair_embedder"
:
{
"c_z"
:
c_z
,
"c_z"
:
c_z
,
"c_m"
:
c_m
,
"c_m"
:
c_m
,
"relpos_k"
:
32
,
"relpos_k"
:
32
,
...
@@ -841,8 +828,6 @@ multimer_model_config_update = {
...
@@ -841,8 +828,6 @@ multimer_model_config_update = {
},
},
"recycle_early_stop_tolerance"
:
0.5
"recycle_early_stop_tolerance"
:
0.5
},
},
"recycle_early_stop_tolerance"
:
0.5
},
"loss"
:
{
"loss"
:
{
"distogram"
:
{
"distogram"
:
{
"min_bin"
:
2.3125
,
"min_bin"
:
2.3125
,
...
@@ -863,7 +848,7 @@ multimer_model_config_update = {
...
@@ -863,7 +848,7 @@ multimer_model_config_update = {
"loss_unit_distance"
:
10.0
,
"loss_unit_distance"
:
10.0
,
"weight"
:
0.5
,
"weight"
:
0.5
,
},
},
"interface"
:
{
"interface
_backbone
"
:
{
"clamp_distance"
:
30.0
,
"clamp_distance"
:
30.0
,
"loss_unit_distance"
:
20.0
,
"loss_unit_distance"
:
20.0
,
"weight"
:
0.5
,
"weight"
:
0.5
,
...
@@ -918,5 +903,5 @@ multimer_model_config_update = {
...
@@ -918,5 +903,5 @@ multimer_model_config_update = {
"enabled"
:
True
,
"enabled"
:
True
,
},
},
"eps"
:
eps
,
"eps"
:
eps
,
}
,
}
}
}
openfold/data/data_modules.py
View file @
30764cf9
...
@@ -7,7 +7,6 @@ import pickle
...
@@ -7,7 +7,6 @@ import pickle
from
typing
import
Optional
,
Sequence
,
List
,
Any
from
typing
import
Optional
,
Sequence
,
List
,
Any
import
ml_collections
as
mlc
import
ml_collections
as
mlc
import
numpy
as
np
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
import
torch
import
torch
from
torch.utils.data
import
RandomSampler
from
torch.utils.data
import
RandomSampler
...
@@ -18,7 +17,7 @@ from openfold.data import (
...
@@ -18,7 +17,7 @@ from openfold.data import (
mmcif_parsing
,
mmcif_parsing
,
templates
,
templates
,
)
)
from
openfold.utils.tensor_utils
import
tensor_tree_map
,
dict_multimap
from
openfold.utils.tensor_utils
import
dict_multimap
import
contextlib
import
contextlib
import
tempfile
import
tempfile
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
...
@@ -34,6 +33,7 @@ def temp_fasta_file(sequence_str):
...
@@ -34,6 +33,7 @@ def temp_fasta_file(sequence_str):
fasta_file
.
seek
(
0
)
fasta_file
.
seek
(
0
)
yield
fasta_file
.
name
yield
fasta_file
.
name
class
OpenFoldSingleDataset
(
torch
.
utils
.
data
.
Dataset
):
class
OpenFoldSingleDataset
(
torch
.
utils
.
data
.
Dataset
):
def
__init__
(
self
,
def
__init__
(
self
,
data_dir
:
str
,
data_dir
:
str
,
...
@@ -296,6 +296,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -296,6 +296,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
_chain_ids
)
return
len
(
self
.
_chain_ids
)
class
OpenFoldSingleMultimerDataset
(
torch
.
utils
.
data
.
Dataset
):
class
OpenFoldSingleMultimerDataset
(
torch
.
utils
.
data
.
Dataset
):
def
__init__
(
self
,
def
__init__
(
self
,
data_dir
:
str
,
data_dir
:
str
,
...
@@ -549,10 +550,10 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
...
@@ -549,10 +550,10 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
device
=
all_chain_features
[
"aatype"
].
device
)
device
=
all_chain_features
[
"aatype"
].
device
)
return
all_chain_features
return
all_chain_features
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
_chain_ids
)
return
len
(
self
.
_chain_ids
)
def
deterministic_train_filter
(
def
deterministic_train_filter
(
chain_data_cache_entry
:
Any
,
chain_data_cache_entry
:
Any
,
max_resolution
:
float
=
9.
,
max_resolution
:
float
=
9.
,
...
@@ -575,6 +576,7 @@ def deterministic_train_filter(
...
@@ -575,6 +576,7 @@ def deterministic_train_filter(
return
True
return
True
def
deterministic_multimer_train_filter
(
def
deterministic_multimer_train_filter
(
mmcif_data_cache_entry
,
mmcif_data_cache_entry
,
max_resolution
:
float
=
9.
,
max_resolution
:
float
=
9.
,
...
@@ -613,9 +615,10 @@ def deterministic_multimer_train_filter(
...
@@ -613,9 +615,10 @@ def deterministic_multimer_train_filter(
return
True
return
True
def
get_stochastic_train_filter_prob
(
def
get_stochastic_train_filter_prob
(
chain_data_cache_entry
:
Any
,
chain_data_cache_entry
:
Any
,
)
->
List
[
float
]
:
)
->
float
:
# Stochastic filters
# Stochastic filters
probabilities
=
[]
probabilities
=
[]
...
@@ -723,8 +726,8 @@ class OpenFoldDataset(torch.utils.data.Dataset):
...
@@ -723,8 +726,8 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datapoint_idx
=
next
(
samples
)
datapoint_idx
=
next
(
samples
)
self
.
datapoints
.
append
((
dataset_idx
,
datapoint_idx
))
self
.
datapoints
.
append
((
dataset_idx
,
datapoint_idx
))
class
OpenFoldMultimerDataset
(
torch
.
utils
.
data
.
Dataset
):
class
OpenFoldMultimerDataset
(
torch
.
utils
.
data
.
Dataset
):
"""
"""
Create a torch Dataset object for multimer training and
Create a torch Dataset object for multimer training and
add filtering steps described in AlphaFold Multimer's paper:
add filtering steps described in AlphaFold Multimer's paper:
...
@@ -753,7 +756,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
...
@@ -753,7 +756,8 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
chains
=
mmcif_data_cache
[
mmcif_id
][
'chain_ids'
]
chains
=
mmcif_data_cache
[
mmcif_id
][
'chain_ids'
]
mmcif_data_cache_entry
=
mmcif_data_cache
[
mmcif_id
]
mmcif_data_cache_entry
=
mmcif_data_cache
[
mmcif_id
]
if
deterministic_multimer_train_filter
(
mmcif_data_cache_entry
,
if
deterministic_multimer_train_filter
(
mmcif_data_cache_entry
,
max_resolution
=
9
,
minimum_number_of_residues
=
5
):
max_resolution
=
9
,
minimum_number_of_residues
=
5
):
selected_idx
.
append
(
i
)
selected_idx
.
append
(
i
)
return
selected_idx
return
selected_idx
...
@@ -781,11 +785,13 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
...
@@ -781,11 +785,13 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
logging
.
info
(
f
"self.epoch_len is
{
self
.
epoch_len
}
"
)
logging
.
info
(
f
"self.epoch_len is
{
self
.
epoch_len
}
"
)
self
.
datapoints
+=
[(
dataset_idx
,
selected_idx
[
i
])
for
i
in
range
(
self
.
epoch_len
)
]
self
.
datapoints
+=
[(
dataset_idx
,
selected_idx
[
i
])
for
i
in
range
(
self
.
epoch_len
)
]
class
OpenFoldBatchCollator
:
class
OpenFoldBatchCollator
:
def
__call__
(
self
,
prots
):
def
__call__
(
self
,
prots
):
stack_fn
=
partial
(
torch
.
stack
,
dim
=
0
)
stack_fn
=
partial
(
torch
.
stack
,
dim
=
0
)
return
dict_multimap
(
stack_fn
,
prots
)
return
dict_multimap
(
stack_fn
,
prots
)
class
OpenFoldDataLoader
(
torch
.
utils
.
data
.
DataLoader
):
class
OpenFoldDataLoader
(
torch
.
utils
.
data
.
DataLoader
):
def
__init__
(
self
,
*
args
,
config
,
stage
=
"train"
,
generator
=
None
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
config
,
stage
=
"train"
,
generator
=
None
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
...
@@ -873,6 +879,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
...
@@ -873,6 +879,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
return
_batch_prop_gen
(
it
)
return
_batch_prop_gen
(
it
)
class
OpenFoldMultimerDataLoader
(
torch
.
utils
.
data
.
DataLoader
):
class
OpenFoldMultimerDataLoader
(
torch
.
utils
.
data
.
DataLoader
):
def
__init__
(
self
,
*
args
,
config
,
stage
=
"train"
,
generator
=
None
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
config
,
stage
=
"train"
,
generator
=
None
,
**
kwargs
):
super
(
OpenFoldMultimerDataLoader
,
self
).
__init__
(
*
args
,
**
kwargs
)
super
(
OpenFoldMultimerDataLoader
,
self
).
__init__
(
*
args
,
**
kwargs
)
...
@@ -1110,7 +1117,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -1110,7 +1117,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
def
predict_dataloader
(
self
):
def
predict_dataloader
(
self
):
return
self
.
_gen_dataloader
(
"predict"
)
return
self
.
_gen_dataloader
(
"predict"
)
class
OpenFoldMultimerDataModule
(
OpenFoldDataModule
):
class
OpenFoldMultimerDataModule
(
OpenFoldDataModule
):
"""
"""
Create a datamodule specifically for multimer training
Create a datamodule specifically for multimer training
...
...
openfold/data/data_pipeline.py
View file @
30764cf9
...
@@ -784,45 +784,6 @@ class DataPipeline:
...
@@ -784,45 +784,6 @@ class DataPipeline:
return
all_hits
return
all_hits
def
_parse_template_hits
(
self
,
alignment_dir
:
str
,
alignment_index
:
Optional
[
Any
]
=
None
,
input_sequence
=
None
,
)
->
Mapping
[
str
,
Any
]:
all_hits
=
{}
if
(
alignment_index
is
not
None
):
fp
=
open
(
os
.
path
.
join
(
alignment_dir
,
alignment_index
[
"db"
]),
'rb'
)
def
read_template
(
start
,
size
):
fp
.
seek
(
start
)
return
fp
.
read
(
size
).
decode
(
"utf-8"
)
for
(
name
,
start
,
size
)
in
alignment_index
[
"files"
]:
ext
=
os
.
path
.
splitext
(
name
)[
-
1
]
if
(
ext
==
".hhr"
):
hits
=
parsers
.
parse_hhr
(
read_template
(
start
,
size
))
all_hits
[
name
]
=
hits
fp
.
close
()
else
:
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
if
(
ext
==
".hhr"
):
with
open
(
path
,
"r"
)
as
fp
:
hits
=
parsers
.
parse_hhr
(
fp
.
read
())
all_hits
[
f
]
=
hits
elif
(
ext
==
'.sto'
)
and
(
f
.
startswith
(
"hmm"
)):
with
open
(
path
,
"r"
)
as
fp
:
hits
=
parsers
.
parse_hmmsearch_sto
(
fp
.
read
(),
input_sequence
)
all_hits
[
f
]
=
hits
fp
.
close
()
return
all_hits
def
_get_msas
(
self
,
def
_get_msas
(
self
,
alignment_dir
:
str
,
alignment_dir
:
str
,
input_sequence
:
Optional
[
str
]
=
None
,
input_sequence
:
Optional
[
str
]
=
None
,
...
@@ -879,9 +840,9 @@ class DataPipeline:
...
@@ -879,9 +840,9 @@ class DataPipeline:
num_res
=
len
(
input_sequence
)
num_res
=
len
(
input_sequence
)
hits
=
self
.
_parse_template_hit_files
(
hits
=
self
.
_parse_template_hit_files
(
alignment_dir
,
alignment_dir
=
alignment_dir
,
input_sequence
,
input_sequence
=
input_sequence
,
alignment_index
,
alignment_index
=
alignment_index
,
)
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
...
@@ -928,8 +889,9 @@ class DataPipeline:
...
@@ -928,8 +889,9 @@ class DataPipeline:
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
hits
=
self
.
_parse_template_hit_files
(
hits
=
self
.
_parse_template_hit_files
(
alignment_dir
,
alignment_dir
=
alignment_dir
,
alignment_index
,
input_sequence
)
input_sequence
=
input_sequence
,
alignment_index
=
alignment_index
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
input_sequence
,
input_sequence
,
...
@@ -976,8 +938,9 @@ class DataPipeline:
...
@@ -976,8 +938,9 @@ class DataPipeline:
)
)
hits
=
self
.
_parse_template_hit_files
(
hits
=
self
.
_parse_template_hit_files
(
alignment_dir
,
alignment_dir
=
alignment_dir
,
alignment_index
,
input_sequence
input_sequence
=
input_sequence
,
alignment_index
=
alignment_index
,
)
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
...
@@ -1008,8 +971,9 @@ class DataPipeline:
...
@@ -1008,8 +971,9 @@ class DataPipeline:
core_feats
=
make_protein_features
(
protein_object
,
description
)
core_feats
=
make_protein_features
(
protein_object
,
description
)
hits
=
self
.
_parse_template_hit_files
(
hits
=
self
.
_parse_template_hit_files
(
alignment_dir
,
alignment_dir
=
alignment_dir
,
alignment_index
,
input_sequence
input_sequence
=
input_sequence
,
alignment_index
=
alignment_index
,
)
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
...
@@ -1098,7 +1062,10 @@ class DataPipeline:
...
@@ -1098,7 +1062,10 @@ class DataPipeline:
alignment_dir
=
os
.
path
.
join
(
alignment_dir
=
os
.
path
.
join
(
super_alignment_dir
,
desc
super_alignment_dir
,
desc
)
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
alignment_index
=
None
,
input_sequence
=
input_sequence
)
hits
=
self
.
_parse_template_hit_files
(
alignment_dir
=
alignment_dir
,
input_sequence
=
seq
,
alignment_index
=
None
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
seq
,
seq
,
hits
,
hits
,
...
...
openfold/utils/loss.py
View file @
30764cf9
...
@@ -310,10 +310,10 @@ def fape_loss(
...
@@ -310,10 +310,10 @@ def fape_loss(
interface_bb_loss
=
backbone_loss
(
interface_bb_loss
=
backbone_loss
(
traj
=
traj
,
traj
=
traj
,
pair_mask
=
1.
-
intra_chain_mask
,
pair_mask
=
1.
-
intra_chain_mask
,
**
{
**
batch
,
**
config
.
interface
},
**
{
**
batch
,
**
config
.
interface
_backbone
},
)
)
weighted_bb_loss
=
(
intra_chain_bb_loss
*
config
.
intra_chain_backbone
.
weight
weighted_bb_loss
=
(
intra_chain_bb_loss
*
config
.
intra_chain_backbone
.
weight
+
interface_bb_loss
*
config
.
interface
.
weight
)
+
interface_bb_loss
*
config
.
interface
_backbone
.
weight
)
else
:
else
:
bb_loss
=
backbone_loss
(
bb_loss
=
backbone_loss
(
traj
=
traj
,
traj
=
traj
,
...
@@ -541,8 +541,11 @@ def lddt_loss(
...
@@ -541,8 +541,11 @@ def lddt_loss(
cutoff
=
cutoff
,
cutoff
=
cutoff
,
eps
=
eps
eps
=
eps
)
)
score
=
torch
.
nan_to_num
(
score
,
nan
=
torch
.
nanmean
(
score
))
# TODO: Remove after initial pipeline testing
score
=
torch
.
nan_to_num
(
score
,
nan
=
torch
.
nanmean
(
score
))
score
[
score
<
0
]
=
0
score
[
score
<
0
]
=
0
score
=
score
.
detach
()
score
=
score
.
detach
()
bin_index
=
torch
.
floor
(
score
*
no_bins
).
long
()
bin_index
=
torch
.
floor
(
score
*
no_bins
).
long
()
bin_index
=
torch
.
clamp
(
bin_index
,
max
=
(
no_bins
-
1
))
bin_index
=
torch
.
clamp
(
bin_index
,
max
=
(
no_bins
-
1
))
...
@@ -1233,7 +1236,7 @@ def find_structural_violations(
...
@@ -1233,7 +1236,7 @@ def find_structural_violations(
batch
[
"atom14_atom_exists"
]
batch
[
"atom14_atom_exists"
]
*
atomtype_radius
[
batch
[
"residx_atom14_to_atom37"
]]
*
atomtype_radius
[
batch
[
"residx_atom14_to_atom37"
]]
)
)
torch
.
cuda
.
memory_summary
()
# Compute the between residue clash loss.
# Compute the between residue clash loss.
between_residue_clashes
=
between_residue_clash_loss
(
between_residue_clashes
=
between_residue_clash_loss
(
atom14_pred_positions
=
atom14_pred_positions
,
atom14_pred_positions
=
atom14_pred_positions
,
...
@@ -1665,9 +1668,11 @@ def chain_center_of_mass_loss(
...
@@ -1665,9 +1668,11 @@ def chain_center_of_mass_loss(
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
chains
,
_
=
asym_id
.
unique
(
return_counts
=
True
)
chains
=
asym_id
.
unique
()
one_hot
=
torch
.
nn
.
functional
.
one_hot
(
asym_id
.
to
(
torch
.
int64
)
-
1
,
# have to reduce asym_id by one because class values must be smaller than num_classes
num_classes
=
chains
.
shape
[
0
]).
to
(
dtype
=
all_atom_mask
.
dtype
)
# make sure asym_id dtype is int
# Reduce asym_id by one because class values must be smaller than num_classes and asym_ids start at 1
one_hot
=
torch
.
nn
.
functional
.
one_hot
(
asym_id
.
long
()
-
1
,
num_classes
=
chains
.
shape
[
0
]).
to
(
dtype
=
all_atom_mask
.
dtype
)
one_hot
=
one_hot
*
all_atom_mask
one_hot
=
one_hot
*
all_atom_mask
chain_pos_mask
=
one_hot
.
transpose
(
-
2
,
-
1
)
chain_pos_mask
=
one_hot
.
transpose
(
-
2
,
-
1
)
chain_exists
=
torch
.
any
(
chain_pos_mask
,
dim
=-
1
).
float
()
chain_exists
=
torch
.
any
(
chain_pos_mask
,
dim
=-
1
).
float
()
...
@@ -1688,6 +1693,7 @@ def chain_center_of_mass_loss(
...
@@ -1688,6 +1693,7 @@ def chain_center_of_mass_loss(
loss
=
masked_mean
(
loss_mask
,
losses
,
dim
=
(
-
1
,
-
2
))
loss
=
masked_mean
(
loss_mask
,
losses
,
dim
=
(
-
1
,
-
2
))
return
loss
return
loss
# #
# #
# below are the functions required for permutations
# below are the functions required for permutations
# #
# #
...
@@ -1715,6 +1721,7 @@ def kabsch_rotation(P, Q):
...
@@ -1715,6 +1721,7 @@ def kabsch_rotation(P, Q):
assert
rotation
.
shape
==
torch
.
Size
([
3
,
3
])
assert
rotation
.
shape
==
torch
.
Size
([
3
,
3
])
return
rotation
.
to
(
'cuda'
)
return
rotation
.
to
(
'cuda'
)
def
get_optimal_transform
(
def
get_optimal_transform
(
src_atoms
:
torch
.
Tensor
,
src_atoms
:
torch
.
Tensor
,
tgt_atoms
:
torch
.
Tensor
,
tgt_atoms
:
torch
.
Tensor
,
...
...
tests/compare_utils.py
View file @
30764cf9
...
@@ -51,7 +51,8 @@ def get_alphafold_config():
...
@@ -51,7 +51,8 @@ def get_alphafold_config():
return
config
return
config
_param_path
=
f
"openfold/resources/params/params_
{
consts
.
model
}
.npz"
dir_path
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
_param_path
=
os
.
path
.
join
(
dir_path
,
".."
,
f
"openfold/resources/params/params_
{
consts
.
model
}
.npz"
)
_model
=
None
_model
=
None
...
...
tests/test_template.py
View file @
30764cf9
...
@@ -256,7 +256,7 @@ class Template(unittest.TestCase):
...
@@ -256,7 +256,7 @@ class Template(unittest.TestCase):
template_feats
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
template_feats
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
if
consts
.
is_multimer
:
if
consts
.
is_multimer
:
out_repro
=
model
.
template_embedder
(
out_repro
_all
=
model
.
template_embedder
(
template_feats
,
template_feats
,
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
...
@@ -267,7 +267,7 @@ class Template(unittest.TestCase):
...
@@ -267,7 +267,7 @@ class Template(unittest.TestCase):
inplace_safe
=
False
inplace_safe
=
False
)
)
else
:
else
:
out_repro
=
model
.
template_embedder
(
out_repro
_all
=
model
.
template_embedder
(
template_feats
,
template_feats
,
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
...
@@ -277,10 +277,10 @@ class Template(unittest.TestCase):
...
@@ -277,10 +277,10 @@ class Template(unittest.TestCase):
inplace_safe
=
False
inplace_safe
=
False
)
)
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro
_all
[
"template_pair_embedding"
]
out_repro
=
out_repro
.
cpu
()
out_repro
=
out_repro
.
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
)
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
)
<
consts
.
eps
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment