Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
39a6d0e6
Commit
39a6d0e6
authored
Apr 09, 2023
by
Christina Floristean
Browse files
Merging in main branch
parents
d8ee9c5f
84659c93
Changes
101
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
478 additions
and
130 deletions
+478
-130
scripts/utils.py
scripts/utils.py
+67
-1
scripts/zero_to_fp32.py
scripts/zero_to_fp32.py
+12
-0
setup.py
setup.py
+72
-45
tests/config.py
tests/config.py
+1
-0
tests/data_utils.py
tests/data_utils.py
+4
-1
tests/test_data_transforms.py
tests/test_data_transforms.py
+5
-2
tests/test_embedders.py
tests/test_embedders.py
+4
-4
tests/test_evoformer.py
tests/test_evoformer.py
+20
-5
tests/test_feats.py
tests/test_feats.py
+28
-14
tests/test_loss.py
tests/test_loss.py
+11
-4
tests/test_model.py
tests/test_model.py
+8
-4
tests/test_msa.py
tests/test_msa.py
+2
-4
tests/test_outer_product_mean.py
tests/test_outer_product_mean.py
+1
-1
tests/test_primitives.py
tests/test_primitives.py
+9
-30
tests/test_structure_module.py
tests/test_structure_module.py
+7
-4
tests/test_template.py
tests/test_template.py
+10
-6
tests/test_triangular_attention.py
tests/test_triangular_attention.py
+8
-1
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+35
-3
tests/test_utils.py
tests/test_utils.py
+1
-1
thread_sequence.py
thread_sequence.py
+173
-0
No files found.
scripts/utils.py
View file @
39a6d0e6
import
argparse
import
argparse
import
ctypes
from
datetime
import
date
from
datetime
import
date
import
sys
def
add_data_args
(
parser
:
argparse
.
ArgumentParser
):
def
add_data_args
(
parser
:
argparse
.
ArgumentParser
):
...
@@ -43,7 +45,7 @@ def add_data_args(parser: argparse.ArgumentParser):
...
@@ -43,7 +45,7 @@ def add_data_args(parser: argparse.ArgumentParser):
'--kalign_binary_path'
,
type
=
str
,
default
=
'/usr/bin/kalign'
'--kalign_binary_path'
,
type
=
str
,
default
=
'/usr/bin/kalign'
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'--max_template_date'
,
type
=
str
,
'--max_template_date'
,
type
=
str
,
default
=
date
.
today
().
strftime
(
"%Y-%m-%d"
),
default
=
date
.
today
().
strftime
(
"%Y-%m-%d"
),
)
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -52,3 +54,67 @@ def add_data_args(parser: argparse.ArgumentParser):
...
@@ -52,3 +54,67 @@ def add_data_args(parser: argparse.ArgumentParser):
parser
.
add_argument
(
parser
.
add_argument
(
'--release_dates_path'
,
type
=
str
,
default
=
None
'--release_dates_path'
,
type
=
str
,
default
=
None
)
)
def
get_nvidia_cc
():
"""
Returns a tuple containing the Compute Capability of the first GPU
installed in the system (formatted as a tuple of strings) and an error
message. When the former is provided, the latter is None, and vice versa.
Adapted from script by Jan Schlüte t
https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
"""
CUDA_SUCCESS
=
0
libnames
=
[
'libcuda.so'
,
'libcuda.dylib'
,
'cuda.dll'
,
'/usr/local/cuda/compat/libcuda.so'
,
# For Docker
]
for
libname
in
libnames
:
try
:
cuda
=
ctypes
.
CDLL
(
libname
)
except
OSError
:
continue
else
:
break
else
:
return
None
,
"Could not load any of: "
+
' '
.
join
(
libnames
)
nGpus
=
ctypes
.
c_int
()
cc_major
=
ctypes
.
c_int
()
cc_minor
=
ctypes
.
c_int
()
result
=
ctypes
.
c_int
()
device
=
ctypes
.
c_int
()
error_str
=
ctypes
.
c_char_p
()
result
=
cuda
.
cuInit
(
0
)
if
result
!=
CUDA_SUCCESS
:
cuda
.
cuGetErrorString
(
result
,
ctypes
.
byref
(
error_str
))
if
error_str
.
value
:
return
None
,
error_str
.
value
.
decode
()
else
:
return
None
,
"Unknown error: cuInit returned %d"
%
result
result
=
cuda
.
cuDeviceGetCount
(
ctypes
.
byref
(
nGpus
))
if
result
!=
CUDA_SUCCESS
:
cuda
.
cuGetErrorString
(
result
,
ctypes
.
byref
(
error_str
))
return
None
,
error_str
.
value
.
decode
()
if
nGpus
.
value
<
1
:
return
None
,
"No GPUs detected"
result
=
cuda
.
cuDeviceGet
(
ctypes
.
byref
(
device
),
0
)
if
result
!=
CUDA_SUCCESS
:
cuda
.
cuGetErrorString
(
result
,
ctypes
.
byref
(
error_str
))
return
None
,
error_str
.
value
.
decode
()
if
cuda
.
cuDeviceComputeCapability
(
ctypes
.
byref
(
cc_major
),
ctypes
.
byref
(
cc_minor
),
device
)
!=
CUDA_SUCCESS
:
return
None
,
"Compute Capability not found"
major
=
cc_major
.
value
minor
=
cc_minor
.
value
return
(
major
,
minor
),
None
scripts/zero_to_fp32.py
View file @
39a6d0e6
...
@@ -13,6 +13,7 @@ import glob
...
@@ -13,6 +13,7 @@ import glob
import
math
import
math
import
os
import
os
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
re
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
# DeepSpeed data structures it has to be available in the current python environment.
...
@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
...
@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
return
model
return
model
def
get_global_step_from_zero_checkpoint
(
checkpoint_dir
):
global_step
=
-
1
latest_path
=
os
.
path
.
join
(
checkpoint_dir
,
'latest'
)
if
os
.
path
.
isfile
(
latest_path
):
with
open
(
latest_path
,
'r'
)
as
fd
:
tag
=
fd
.
read
().
strip
()
match
=
re
.
match
(
r
"global_step([0-9]+)"
,
tag
)
global_step
=
int
(
match
.
group
(
1
))
else
:
raise
ValueError
(
f
"Unable to find 'latest' file at
{
latest_path
}
"
)
return
global_step
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
setup.py
View file @
39a6d0e6
...
@@ -16,7 +16,9 @@ import os
...
@@ -16,7 +16,9 @@ import os
from
setuptools
import
setup
,
Extension
,
find_packages
from
setuptools
import
setup
,
Extension
,
find_packages
import
subprocess
import
subprocess
from
torch.utils.cpp_extension
import
BuildExtension
,
CUDAExtension
,
CUDA_HOME
from
torch.utils.cpp_extension
import
BuildExtension
,
CppExtension
,
CUDAExtension
,
CUDA_HOME
from
scripts.utils
import
get_nvidia_cc
version_dependent_macros
=
[
version_dependent_macros
=
[
...
@@ -26,48 +28,56 @@ version_dependent_macros = [
...
@@ -26,48 +28,56 @@ version_dependent_macros = [
]
]
extra_cuda_flags
=
[
extra_cuda_flags
=
[
'-std=c++14'
,
'-std=c++14'
,
'-maxrregcount=50'
,
'-maxrregcount=50'
,
'-U__CUDA_NO_HALF_OPERATORS__'
,
'-U__CUDA_NO_HALF_OPERATORS__'
,
'-U__CUDA_NO_HALF_CONVERSIONS__'
,
'-U__CUDA_NO_HALF_CONVERSIONS__'
,
'--expt-relaxed-constexpr'
,
'--expt-relaxed-constexpr'
,
'--expt-extended-lambda'
'--expt-extended-lambda'
]
]
def
get_cuda_bare_metal_version
(
cuda_dir
):
def
get_cuda_bare_metal_version
(
cuda_dir
):
raw_output
=
subprocess
.
check_output
([
cuda_dir
+
"/bin/nvcc"
,
"-V"
],
universal_newlines
=
True
)
if
cuda_dir
==
None
:
output
=
raw_output
.
split
()
print
(
"CUDA is not found, cpu version is installed"
)
release_idx
=
output
.
index
(
"release"
)
+
1
return
None
,
-
1
,
0
release
=
output
[
release_idx
].
split
(
"."
)
else
:
bare_metal_major
=
release
[
0
]
raw_output
=
subprocess
.
check_output
([
cuda_dir
+
"/bin/nvcc"
,
"-V"
],
universal_newlines
=
True
)
bare_metal_minor
=
release
[
1
][
0
]
output
=
raw_output
.
split
()
release_idx
=
output
.
index
(
"release"
)
+
1
release
=
output
[
release_idx
].
split
(
"."
)
bare_metal_major
=
release
[
0
]
bare_metal_minor
=
release
[
1
][
0
]
return
raw_output
,
bare_metal_major
,
bare_metal_minor
return
raw_output
,
bare_metal_major
,
bare_metal_minor
compute_capabilities
=
set
([
(
3
,
7
),
# K80, e.g.
(
5
,
2
),
# Titan X
(
6
,
1
),
# GeForce 1000-series
])
c
c_flag
=
[
'-gencode'
,
'arch=compute_70,code=sm_70'
]
c
ompute_capabilities
.
add
((
7
,
0
))
_
,
bare_metal_major
,
_
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
_
,
bare_metal_major
,
_
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
if
int
(
bare_metal_major
)
>=
11
:
if
int
(
bare_metal_major
)
>=
11
:
cc_flag
.
append
(
'-gencode'
)
compute_capabilities
.
add
((
8
,
0
))
cc_flag
.
append
(
'arch=compute_80,code=sm_80'
)
compute_capability
,
_
=
get_nvidia_cc
()
if
compute_capability
is
not
None
:
compute_capabilities
=
set
([
compute_capability
])
cc_flag
=
[]
for
major
,
minor
in
list
(
compute_capabilities
):
cc_flag
.
extend
([
'-gencode'
,
f
'arch=compute_
{
major
}{
minor
}
,code=sm_
{
major
}{
minor
}
'
,
])
extra_cuda_flags
+=
cc_flag
extra_cuda_flags
+=
cc_flag
cc_flag
=
[
'-gencode'
,
'arch=compute_70,code=sm_70'
]
setup
(
if
bare_metal_major
!=
-
1
:
name
=
'openfold'
,
modules
=
[
CUDAExtension
(
version
=
'0.1.0'
,
description
=
'A PyTorch reimplementation of DeepMind
\'
s AlphaFold 2'
,
author
=
'Gustaf Ahdritz & DeepMind'
,
author_email
=
'gahdritz@gmail.com'
,
license
=
'Apache License, Version 2.0'
,
url
=
'https://github.com/aqlaboratory/openfold'
,
packages
=
find_packages
(
exclude
=
[
"tests"
,
"scripts"
]),
include_package_data
=
True
,
package_data
=
{
"openfold"
:
[
'utils/kernel/csrc/*'
],
""
:
[
"resources/stereo_chemical_props.txt"
]
},
ext_modules
=
[
CUDAExtension
(
name
=
"attn_core_inplace_cuda"
,
name
=
"attn_core_inplace_cuda"
,
sources
=
[
sources
=
[
"openfold/utils/kernel/csrc/softmax_cuda.cpp"
,
"openfold/utils/kernel/csrc/softmax_cuda.cpp"
,
...
@@ -75,34 +85,51 @@ setup(
...
@@ -75,34 +85,51 @@ setup(
],
],
include_dirs
=
[
include_dirs
=
[
os
.
path
.
join
(
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
'openfold/utils/kernel/csrc/'
'openfold/utils/kernel/csrc/'
)
)
],
],
extra_compile_args
=
{
extra_compile_args
=
{
'cxx'
:
[
'-O3'
]
+
version_dependent_macros
,
'cxx'
:
[
'-O3'
]
+
version_dependent_macros
,
'nvcc'
:
(
'nvcc'
:
(
[
'-O3'
,
'--use_fast_math'
]
+
[
'-O3'
,
'--use_fast_math'
]
+
version_dependent_macros
+
version_dependent_macros
+
extra_cuda_flags
extra_cuda_flags
),
),
}
}
)],
)]
else
:
modules
=
[
CppExtension
(
name
=
"attn_core_inplace_cuda"
,
sources
=
[
"openfold/utils/kernel/csrc/softmax_cuda.cpp"
,
"openfold/utils/kernel/csrc/softmax_cuda_stub.cpp"
,
],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
],
}
)]
setup
(
name
=
'openfold'
,
version
=
'1.0.1'
,
description
=
'A PyTorch reimplementation of DeepMind
\'
s AlphaFold 2'
,
author
=
'Gustaf Ahdritz & DeepMind'
,
author_email
=
'gahdritz@gmail.com'
,
license
=
'Apache License, Version 2.0'
,
url
=
'https://github.com/aqlaboratory/openfold'
,
packages
=
find_packages
(
exclude
=
[
"tests"
,
"scripts"
]),
include_package_data
=
True
,
package_data
=
{
"openfold"
:
[
'utils/kernel/csrc/*'
],
""
:
[
"resources/stereo_chemical_props.txt"
]
},
ext_modules
=
modules
,
cmdclass
=
{
'build_ext'
:
BuildExtension
},
cmdclass
=
{
'build_ext'
:
BuildExtension
},
install_requires
=
[
'torch'
,
'deepspeed'
,
'biopython'
,
'ml-collections'
,
'numpy'
,
'scipy'
,
'pytorch_lightning'
,
'dm-tree'
,
],
classifiers
=
[
classifiers
=
[
'License :: OSI Approved :: Apache Software License'
,
'License :: OSI Approved :: Apache Software License'
,
'Operating System :: POSIX :: Linux'
,
'Operating System :: POSIX :: Linux'
,
'Programming Language :: Python :: 3.7,'
'Programming Language :: Python :: 3.7,'
'Topic :: Scientific/Engineering :: Artificial Intelligence'
,
'Topic :: Scientific/Engineering :: Artificial Intelligence'
,
],
],
)
)
tests/config.py
View file @
39a6d0e6
...
@@ -10,6 +10,7 @@ consts = mlc.ConfigDict(
...
@@ -10,6 +10,7 @@ consts = mlc.ConfigDict(
"n_seq"
:
13
,
"n_seq"
:
13
,
"n_templ"
:
3
,
"n_templ"
:
3
,
"n_extra"
:
17
,
"n_extra"
:
17
,
"n_heads_extra_msa"
:
8
,
"eps"
:
5e-4
,
"eps"
:
5e-4
,
# For compatibility with DeepMind's pretrained weights, it's easiest for
# For compatibility with DeepMind's pretrained weights, it's easiest for
# everyone if these take their real values.
# everyone if these take their real values.
...
...
tests/data_utils.py
View file @
39a6d0e6
...
@@ -30,7 +30,10 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
...
@@ -30,7 +30,10 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
pieces
=
[]
pieces
=
[]
asym_ids
=
[]
asym_ids
=
[]
for
idx
in
range
(
n_chain
-
1
):
for
idx
in
range
(
n_chain
-
1
):
piece
=
randint
(
min_chain_len
,
(
n_res
-
sum
(
pieces
)
-
n_chain
+
idx
-
min_chain_len
))
n_stop
=
(
n_res
-
sum
(
pieces
)
-
n_chain
+
idx
-
min_chain_len
)
if
n_stop
<=
min_chain_len
:
break
piece
=
randint
(
min_chain_len
,
n_stop
)
pieces
.
append
(
piece
)
pieces
.
append
(
piece
)
asym_ids
.
extend
(
piece
*
[
idx
])
asym_ids
.
extend
(
piece
*
[
idx
])
asym_ids
.
extend
((
n_res
-
sum
(
pieces
))
*
[
n_chain
-
1
])
asym_ids
.
extend
((
n_res
-
sum
(
pieces
))
*
[
n_chain
-
1
])
...
...
tests/test_data_transforms.py
View file @
39a6d0e6
...
@@ -45,7 +45,7 @@ class TestDataTransforms(unittest.TestCase):
...
@@ -45,7 +45,7 @@ class TestDataTransforms(unittest.TestCase):
template_seq_one_hot
=
torch
.
FloatTensor
(
template_seq
.
shape
[
0
],
20
).
zero_
()
template_seq_one_hot
=
torch
.
FloatTensor
(
template_seq
.
shape
[
0
],
20
).
zero_
()
template_seq_one_hot
.
scatter_
(
1
,
template_seq
,
1
)
template_seq_one_hot
.
scatter_
(
1
,
template_seq
,
1
)
template_aatype
=
template_seq_one_hot
.
clone
().
detach
().
unsqueeze
(
0
)
template_aatype
=
template_seq_one_hot
.
clone
().
detach
().
unsqueeze
(
0
)
protein
=
{
'template_aatype'
:
template_aatype
}
protein
=
{
'template_aatype'
:
template_aatype
,
'aatype'
:
template_aatype
}
protein
=
fix_templates_aatype
(
protein
)
protein
=
fix_templates_aatype
(
protein
)
template_seq_ours
=
torch
.
tensor
([[
0
,
4
,
3
,
6
,
13
,
7
,
8
,
9
,
11
,
10
,
12
,
2
,
14
,
5
,
1
,
15
,
16
,
19
,
17
,
18
]
*
2
])
template_seq_ours
=
torch
.
tensor
([[
0
,
4
,
3
,
6
,
13
,
7
,
8
,
9
,
11
,
10
,
12
,
2
,
14
,
5
,
1
,
15
,
16
,
19
,
17
,
18
]
*
2
])
assert
torch
.
all
(
torch
.
eq
(
protein
[
'template_aatype'
],
template_seq_ours
))
assert
torch
.
all
(
torch
.
eq
(
protein
[
'template_aatype'
],
template_seq_ours
))
...
@@ -171,7 +171,10 @@ class TestDataTransforms(unittest.TestCase):
...
@@ -171,7 +171,10 @@ class TestDataTransforms(unittest.TestCase):
with
open
(
'tests/test_data/features.pkl'
,
'rb'
)
as
file
:
with
open
(
'tests/test_data/features.pkl'
,
'rb'
)
as
file
:
features
=
pickle
.
load
(
file
)
features
=
pickle
.
load
(
file
)
protein
=
{
'msa'
:
torch
.
tensor
(
features
[
'msa'
],
dtype
=
torch
.
int64
)}
protein
=
{
'msa'
:
torch
.
tensor
(
features
[
'msa'
],
dtype
=
torch
.
int64
),
'aatype'
:
torch
.
tensor
(
features
[
'aatype'
],
dtype
=
torch
.
int64
),
}
protein
=
make_hhblits_profile
(
protein
)
protein
=
make_hhblits_profile
(
protein
)
masked_msa_config
=
config
.
data
.
common
.
masked_msa
masked_msa_config
=
config
.
data
.
common
.
masked_msa
protein
=
make_masked_msa
.
__wrapped__
(
protein
,
masked_msa_config
,
replace_fraction
=
0.15
,
seed
=
42
)
protein
=
make_masked_msa
.
__wrapped__
(
protein
,
masked_msa_config
,
replace_fraction
=
0.15
,
seed
=
42
)
...
...
tests/test_embedders.py
View file @
39a6d0e6
...
@@ -50,18 +50,18 @@ class TestInputEmbedder(unittest.TestCase):
...
@@ -50,18 +50,18 @@ class TestInputEmbedder(unittest.TestCase):
entity_id
=
asym_id
entity_id
=
asym_id
sym_id
=
torch
.
zeros_like
(
entity_id
)
sym_id
=
torch
.
zeros_like
(
entity_id
)
batch
=
{
"target_feat"
:
tf
,
"residue_index"
:
ri
,
"msa_feat"
:
msa
}
if
consts
.
is_multimer
:
if
consts
.
is_multimer
:
ie
=
InputEmbedderMultimer
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
ie
=
InputEmbedderMultimer
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
max_relative_idx
=
max_relative_idx
,
max_relative_idx
=
max_relative_idx
,
use_chain_relative
=
use_chain_relative
,
use_chain_relative
=
use_chain_relative
,
max_relative_chain
=
max_relative_chain
)
max_relative_chain
=
max_relative_chain
)
batch
.
update
({
"asym_id"
:
asym_id
,
"entity_id"
:
entity_id
,
"sym_id"
:
sym_id
})
batch
=
{
"target_feat"
:
tf
,
"residue_index"
:
ri
,
"msa_feat"
:
msa
,
"asym_id"
:
asym_id
,
"entity_id"
:
entity_id
,
"sym_id"
:
sym_id
}
msa_emb
,
pair_emb
=
ie
(
batch
)
else
:
else
:
ie
=
InputEmbedder
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
relpos_k
)
ie
=
InputEmbedder
(
tf_dim
,
msa_dim
,
c_z
,
c_m
,
relpos_k
)
msa_emb
,
pair_emb
=
ie
(
tf
=
tf
,
ri
=
ri
,
msa
=
msa
,
inplace_safe
=
False
)
msa_emb
,
pair_emb
=
ie
(
batch
)
self
.
assertTrue
(
msa_emb
.
shape
==
(
b
,
n_clust
,
n_res
,
c_m
))
self
.
assertTrue
(
msa_emb
.
shape
==
(
b
,
n_clust
,
n_res
,
c_m
))
self
.
assertTrue
(
pair_emb
.
shape
==
(
b
,
n_res
,
n_res
,
c_z
))
self
.
assertTrue
(
pair_emb
.
shape
==
(
b
,
n_res
,
n_res
,
c_z
))
...
...
tests/test_evoformer.py
View file @
39a6d0e6
...
@@ -132,13 +132,31 @@ class TestEvoformerStack(unittest.TestCase):
...
@@ -132,13 +132,31 @@ class TestEvoformerStack(unittest.TestCase):
torch
.
as_tensor
(
masks
[
"pair"
]).
cuda
(),
torch
.
as_tensor
(
masks
[
"pair"
]).
cuda
(),
chunk_size
=
4
,
chunk_size
=
4
,
_mask_trans
=
False
,
_mask_trans
=
False
,
inplace_safe
=
False
,
)
)
out_repro_msa
=
out_repro_msa
.
cpu
()
out_repro_msa
=
out_repro_msa
.
cpu
()
out_repro_pair
=
out_repro_pair
.
cpu
()
out_repro_pair
=
out_repro_pair
.
cpu
()
assert
(
torch
.
max
(
torch
.
abs
(
out_repro_msa
-
out_gt_msa
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_repro_msa
-
out_gt_msa
))
<
consts
.
eps
)
assert
(
torch
.
max
(
torch
.
abs
(
out_repro_pair
-
out_gt_pair
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_repro_pair
-
out_gt_pair
))
<
consts
.
eps
)
# Inplace version
out_repro_msa
,
out_repro_pair
=
model
.
evoformer
.
blocks
[
0
](
torch
.
as_tensor
(
activations
[
"msa"
]).
cuda
(),
torch
.
as_tensor
(
activations
[
"pair"
]).
cuda
(),
torch
.
as_tensor
(
masks
[
"msa"
]).
cuda
(),
torch
.
as_tensor
(
masks
[
"pair"
]).
cuda
(),
chunk_size
=
4
,
_mask_trans
=
False
,
inplace_safe
=
True
,
)
out_repro_msa
=
out_repro_msa
.
cpu
()
out_repro_pair
=
out_repro_pair
.
cpu
()
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_repro_msa
-
out_gt_msa
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_repro_pair
-
out_gt_pair
))
<
consts
.
eps
)
class
TestExtraMSAStack
(
unittest
.
TestCase
):
class
TestExtraMSAStack
(
unittest
.
TestCase
):
...
@@ -270,9 +288,6 @@ class TestMSATransition(unittest.TestCase):
...
@@ -270,9 +288,6 @@ class TestMSATransition(unittest.TestCase):
.
cpu
()
.
cpu
()
)
)
print
(
out_gt
)
print
(
out_repro
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
...
...
tests/test_feats.py
View file @
39a6d0e6
...
@@ -34,7 +34,7 @@ from openfold.utils.tensor_utils import (
...
@@ -34,7 +34,7 @@ from openfold.utils.tensor_utils import (
)
)
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
from
tests.data_utils
import
random_affines_4x4
from
tests.data_utils
import
random_affines_4x4
,
random_asym_ids
if
compare_utils
.
alphafold_is_installed
():
if
compare_utils
.
alphafold_is_installed
():
alphafold
=
compare_utils
.
import_alphafold
()
alphafold
=
compare_utils
.
import_alphafold
()
...
@@ -170,14 +170,21 @@ class TestFeats(unittest.TestCase):
...
@@ -170,14 +170,21 @@ class TestFeats(unittest.TestCase):
out_gt
=
f
.
apply
({},
None
,
**
batch
)
out_gt
=
f
.
apply
({},
None
,
**
batch
)
if
consts
.
is_multimer
:
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
to_tensor
=
(
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
to_tensor
=
(
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
if
not
isinstance
(
t
,
self
.
am_rigid
.
Rigid3Array
)
if
not
isinstance
(
t
,
self
.
am_rigid
.
Rigid3Array
)
else
torch
.
tensor
(
np
.
array
(
t
.
to_array
()))
.
view
(
*
t
.
shape
[:
2
],
12
)
)
else
torch
.
tensor
(
np
.
array
(
t
.
to_array
())))
else
:
else
:
to_tensor
=
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
to_tensor
=
lambda
t
:
torch
.
tensor
(
np
.
array
(
t
))
out_gt
=
{
k
:
to_tensor
(
v
)
for
k
,
v
in
out_gt
.
items
()}
out_gt
=
{
k
:
to_tensor
(
v
)
for
k
,
v
in
out_gt
.
items
()}
def
rigid3x4_to_4x4
(
rigid3arr
):
four_by_four
=
torch
.
zeros
(
*
rigid3arr
.
shape
[:
-
2
],
4
,
4
)
four_by_four
[...,
:
3
,
:
4
]
=
rigid3arr
four_by_four
[...,
3
,
3
]
=
1
return
four_by_four
def
flat12_to_4x4
(
flat12
):
def
flat12_to_4x4
(
flat12
):
rot
=
flat12
[...,
:
9
].
view
(
*
flat12
.
shape
[:
-
1
],
3
,
3
)
rot
=
flat12
[...,
:
9
].
view
(
*
flat12
.
shape
[:
-
1
],
3
,
3
)
trans
=
flat12
[...,
9
:]
trans
=
flat12
[...,
9
:]
...
@@ -189,10 +196,12 @@ class TestFeats(unittest.TestCase):
...
@@ -189,10 +196,12 @@ class TestFeats(unittest.TestCase):
return
four_by_four
return
four_by_four
out_gt
[
"rigidgroups_gt_frames"
]
=
flat12_to_4x4
(
convert_func
=
rigid3x4_to_4x4
if
consts
.
is_multimer
else
flat12_to_4x4
out_gt
[
"rigidgroups_gt_frames"
]
=
convert_func
(
out_gt
[
"rigidgroups_gt_frames"
]
out_gt
[
"rigidgroups_gt_frames"
]
)
)
out_gt
[
"rigidgroups_alt_gt_frames"
]
=
flat12_to_4x4
(
out_gt
[
"rigidgroups_alt_gt_frames"
]
=
convert_func
(
out_gt
[
"rigidgroups_alt_gt_frames"
]
out_gt
[
"rigidgroups_alt_gt_frames"
]
)
)
...
@@ -278,13 +287,21 @@ class TestFeats(unittest.TestCase):
...
@@ -278,13 +287,21 @@ class TestFeats(unittest.TestCase):
)
)
# Convert the Rigids to 4x4 transformation tensors
# Convert the Rigids to 4x4 transformation tensors
rots_gt
=
list
(
map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt
.
rot
))
out_gt_rot
=
out_gt
.
rot
if
not
consts
.
is_multimer
else
out_gt
.
rotation
.
to_array
()
trans_gt
=
list
(
out_gt_trans
=
out_gt
.
trans
if
not
consts
.
is_multimer
else
out_gt
.
translation
.
to_array
()
map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt
.
trans
)
)
if
consts
.
is_multimer
:
rots_gt
=
torch
.
cat
([
x
.
unsqueeze
(
-
1
)
for
x
in
rots_gt
],
dim
=-
1
)
rots_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt_rot
))
rots_gt
=
rots_gt
.
view
(
*
rots_gt
.
shape
[:
-
1
],
3
,
3
)
trans_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt_trans
))
trans_gt
=
torch
.
cat
([
x
.
unsqueeze
(
-
1
)
for
x
in
trans_gt
],
dim
=-
1
)
else
:
rots_gt
=
list
(
map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt_rot
))
trans_gt
=
list
(
map
(
lambda
x
:
torch
.
as_tensor
(
np
.
array
(
x
)),
out_gt_trans
)
)
rots_gt
=
torch
.
cat
([
x
.
unsqueeze
(
-
1
)
for
x
in
rots_gt
],
dim
=-
1
)
rots_gt
=
rots_gt
.
view
(
*
rots_gt
.
shape
[:
-
1
],
3
,
3
)
trans_gt
=
torch
.
cat
([
x
.
unsqueeze
(
-
1
)
for
x
in
trans_gt
],
dim
=-
1
)
transforms_gt
=
torch
.
cat
([
rots_gt
,
trans_gt
.
unsqueeze
(
-
1
)],
dim
=-
1
)
transforms_gt
=
torch
.
cat
([
rots_gt
,
trans_gt
.
unsqueeze
(
-
1
)],
dim
=-
1
)
bottom_row
=
torch
.
zeros
((
*
rots_gt
.
shape
[:
-
2
],
1
,
4
))
bottom_row
=
torch
.
zeros
((
*
rots_gt
.
shape
[:
-
2
],
1
,
4
))
bottom_row
[...,
3
]
=
1
bottom_row
[...,
3
]
=
1
...
@@ -321,9 +338,6 @@ class TestFeats(unittest.TestCase):
...
@@ -321,9 +338,6 @@ class TestFeats(unittest.TestCase):
torch
.
tensor
(
restype_atom14_rigid_group_positions
),
torch
.
tensor
(
restype_atom14_rigid_group_positions
),
)
)
if
consts
.
is_multimer
:
xyz
=
xyz
.
to_tensor
()
self
.
assertTrue
(
xyz
.
shape
==
(
batch_size
,
n_res
,
14
,
3
))
self
.
assertTrue
(
xyz
.
shape
==
(
batch_size
,
n_res
,
14
,
3
))
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
...
...
tests/test_loss.py
View file @
39a6d0e6
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
import
os
import
os
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
from
pathlib
import
Path
import
unittest
import
unittest
import
ml_collections
as
mlc
import
ml_collections
as
mlc
...
@@ -301,7 +302,8 @@ class TestLoss(unittest.TestCase):
...
@@ -301,7 +302,8 @@ class TestLoss(unittest.TestCase):
def
test_find_structural_violations_compare
(
self
):
def
test_find_structural_violations_compare
(
self
):
def
run_fsv
(
batch
,
pos
,
config
):
def
run_fsv
(
batch
,
pos
,
config
):
cwd
=
os
.
getcwd
()
cwd
=
os
.
getcwd
()
os
.
chdir
(
"tests/test_data"
)
fpath
=
Path
(
__file__
).
parent
.
resolve
()
/
"test_data"
os
.
chdir
(
str
(
fpath
))
if
consts
.
is_multimer
:
if
consts
.
is_multimer
:
atom14_pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pos
)
atom14_pred_pos
=
self
.
am_rigid
.
Vec3Array
.
from_array
(
pos
)
...
@@ -436,7 +438,7 @@ class TestLoss(unittest.TestCase):
...
@@ -436,7 +438,7 @@ class TestLoss(unittest.TestCase):
"true_msa"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,
n_seq
)),
"true_msa"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,
n_seq
)),
"bert_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_seq
)).
astype
(
"bert_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_seq
)).
astype
(
np
.
float32
np
.
float32
)
,
)
}
}
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)[
"loss"
]
out_gt
=
f
.
apply
({},
None
,
value
,
batch
)[
"loss"
]
...
@@ -448,7 +450,9 @@ class TestLoss(unittest.TestCase):
...
@@ -448,7 +450,9 @@ class TestLoss(unittest.TestCase):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
out_repro
=
masked_msa_loss
(
out_repro
=
masked_msa_loss
(
value
[
"logits"
],
value
[
"logits"
],
**
batch
,
batch
[
"true_msa"
],
batch
[
"bert_mask"
],
consts
.
msa_logits
)
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
...
@@ -903,6 +907,9 @@ class TestLoss(unittest.TestCase):
...
@@ -903,6 +907,9 @@ class TestLoss(unittest.TestCase):
),
),
}
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
def
_build_extra_feats_np
():
def
_build_extra_feats_np
():
b
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
),
batch
,
np
.
ndarray
)
b
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
),
batch
,
np
.
ndarray
)
b
=
data_transforms
.
make_atom14_masks
(
b
)
b
=
data_transforms
.
make_atom14_masks
(
b
)
...
@@ -943,7 +950,7 @@ class TestLoss(unittest.TestCase):
...
@@ -943,7 +950,7 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
@
unittest
.
skipIf
(
not
consts
.
is_multimer
and
"ptm"
not
in
consts
.
model
,
"Not enabled for non-ptm models."
)
@
unittest
.
skipIf
(
consts
.
is_multimer
or
"ptm"
not
in
consts
.
model
,
"Not enabled for non-ptm models."
)
def
test_tm_loss_compare
(
self
):
def
test_tm_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_tm
=
config
.
model
.
heads
.
predicted_aligned_error
c_tm
=
config
.
model
.
heads
.
predicted_aligned_error
...
...
tests/test_model.py
View file @
39a6d0e6
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
pathlib
import
Path
import
pickle
import
pickle
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -54,7 +55,7 @@ class TestModel(unittest.TestCase):
...
@@ -54,7 +55,7 @@ class TestModel(unittest.TestCase):
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
n_extra_seq
=
consts
.
n_extra
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
consts
.
model
)
c
=
model_config
(
consts
.
model
,
train
=
True
)
c
.
model
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
c
.
model
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
c
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
# deepspeed for this test
# deepspeed for this test
...
@@ -68,6 +69,7 @@ class TestModel(unittest.TestCase):
...
@@ -68,6 +69,7 @@ class TestModel(unittest.TestCase):
).
float
()
).
float
()
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
batch
[
"msa_feat"
]
=
torch
.
rand
((
n_seq
,
n_res
,
c
.
model
.
input_embedder
.
msa_dim
))
t_feats
=
random_template_feats
(
n_templ
,
n_res
)
t_feats
=
random_template_feats
(
n_templ
,
n_res
)
batch
.
update
({
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
t_feats
.
items
()})
batch
.
update
({
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
t_feats
.
items
()})
...
@@ -95,11 +97,14 @@ class TestModel(unittest.TestCase):
...
@@ -95,11 +97,14 @@ class TestModel(unittest.TestCase):
out
=
model
(
batch
)
out
=
model
(
batch
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
@
unittest
.
skipIf
(
consts
.
is_multimer
,
"Additional changes required for multimer."
)
def
test_compare
(
self
):
def
test_compare
(
self
):
#TODO: Fix test data for multimer MSA features
def
run_alphafold
(
batch
):
def
run_alphafold
(
batch
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
model
=
self
.
am_modules
.
AlphaFold
(
config
.
model
)
model
=
self
.
am_modules
.
AlphaFold
(
config
.
model
)
return
model
(
return
model
(
batch
=
batch
,
batch
=
batch
,
is_training
=
False
,
is_training
=
False
,
...
@@ -110,7 +115,8 @@ class TestModel(unittest.TestCase):
...
@@ -110,7 +115,8 @@ class TestModel(unittest.TestCase):
params
=
compare_utils
.
fetch_alphafold_module_weights
(
""
)
params
=
compare_utils
.
fetch_alphafold_module_weights
(
""
)
with
open
(
"tests/test_data/sample_feats.pickle"
,
"rb"
)
as
fp
:
fpath
=
Path
(
__file__
).
parent
.
resolve
()
/
"test_data/sample_feats.pickle"
with
open
(
str
(
fpath
),
"rb"
)
as
fp
:
batch
=
pickle
.
load
(
fp
)
batch
=
pickle
.
load
(
fp
)
out_gt
=
f
.
apply
(
params
,
jax
.
random
.
PRNGKey
(
42
),
batch
)
out_gt
=
f
.
apply
(
params
,
jax
.
random
.
PRNGKey
(
42
),
batch
)
...
@@ -150,6 +156,4 @@ class TestModel(unittest.TestCase):
...
@@ -150,6 +156,4 @@ class TestModel(unittest.TestCase):
out_repro
=
out_repro
[
"sm"
][
"positions"
][
-
1
]
out_repro
=
out_repro
[
"sm"
][
"positions"
][
-
1
]
out_repro
=
out_repro
.
squeeze
(
0
)
out_repro
=
out_repro
.
squeeze
(
0
)
print
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
)))
print
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)))
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
1e-3
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
1e-3
)
tests/test_msa.py
View file @
39a6d0e6
...
@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
...
@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
)
)
).
cpu
()
).
cpu
()
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
)
)
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
)
)
<
consts
.
eps
)
class
TestMSAColumnAttention
(
unittest
.
TestCase
):
class
TestMSAColumnAttention
(
unittest
.
TestCase
):
...
@@ -158,9 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase):
...
@@ -158,9 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase):
)
)
).
cpu
()
).
cpu
()
print
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
)))
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
class
TestMSAColumnGlobalAttention
(
unittest
.
TestCase
):
class
TestMSAColumnGlobalAttention
(
unittest
.
TestCase
):
...
...
tests/test_outer_product_mean.py
View file @
39a6d0e6
...
@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
...
@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
# Even when correct, OPM has large, precision-related errors. It gets
# Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps.
# a special pass from consts.eps.
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
5e-4
)
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
)
<
5e-4
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tests/test_primitives.py
View file @
39a6d0e6
...
@@ -15,54 +15,33 @@
...
@@ -15,54 +15,33 @@
import
torch
import
torch
import
unittest
import
unittest
from
openfold.model.primitives
import
(
from
openfold.model.primitives
import
Attention
Attention
)
from
tests.config
import
consts
from
tests.config
import
consts
class
TestLMA
(
unittest
.
TestCase
):
class
TestLMA
(
unittest
.
TestCase
):
def
test_lma_vs_attention
(
self
):
def
test_lma_vs_attention
(
self
):
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
c_hidden
=
32
c_hidden
=
32
n
=
2
**
12
n
=
2
**
12
no_heads
=
4
no_heads
=
4
q
=
torch
.
rand
(
batch_size
,
n
,
c_hidden
).
cuda
()
q
=
torch
.
rand
(
batch_size
,
n
,
c_hidden
).
cuda
()
k
=
torch
.
rand
(
batch_size
,
n
,
c_hidden
).
cuda
()
kv
=
torch
.
rand
(
batch_size
,
n
,
c_hidden
).
cuda
()
v
=
torch
.
rand
(
batch_size
,
n
,
c_hidden
).
cuda
()
bias
=
[
torch
.
rand
(
no_heads
,
1
,
n
)]
bias
=
[
torch
.
rand
(
no_heads
,
1
,
n
)]
bias
=
[
b
.
cuda
()
for
b
in
bias
]
bias
=
[
b
.
cuda
()
for
b
in
bias
]
gating_fill
=
torch
.
rand
(
c_hidden
*
no_heads
,
c_hidden
)
o_fill
=
torch
.
rand
(
c_hidden
,
c_hidden
*
no_heads
)
lma
=
Attention
(
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
).
cuda
()
a
=
Attention
(
a
=
Attention
(
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
c_hidden
,
c_hidden
,
c_hidden
,
c_hidden
,
no_heads
).
cuda
()
).
cuda
()
with
torch
.
no_grad
():
for
n
,
p
in
lma
.
named_parameters
():
attrs
=
n
.
split
(
'.'
)
param
=
a
for
attr
in
attrs
:
param
=
getattr
(
param
,
attr
)
param
.
copy_
(
p
)
for
m
in
[
lma
,
a
]:
m
.
linear_g
.
weight
.
copy_
(
gating_fill
)
m
.
linear_o
.
weight
.
copy_
(
o_fill
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
l
=
lm
a
(
q
,
k
,
v
,
biases
=
bias
,
use_lma
=
True
,
q_chunk_size
=
1024
,
kv_chunk_size
=
4096
)
l
=
a
(
q
,
kv
,
biases
=
bias
,
use_lma
=
True
)
real
=
a
(
q
,
k
,
v
,
biases
=
bias
)
real
=
a
(
q
,
kv
,
biases
=
bias
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
l
-
real
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
l
-
real
))
<
consts
.
eps
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
tests/test_structure_module.py
View file @
39a6d0e6
...
@@ -99,7 +99,7 @@ class TestStructureModule(unittest.TestCase):
...
@@ -99,7 +99,7 @@ class TestStructureModule(unittest.TestCase):
z
=
torch
.
rand
((
batch_size
,
n
,
n
,
c_z
))
z
=
torch
.
rand
((
batch_size
,
n
,
n
,
c_z
))
f
=
torch
.
randint
(
low
=
0
,
high
=
21
,
size
=
(
batch_size
,
n
)).
long
()
f
=
torch
.
randint
(
low
=
0
,
high
=
21
,
size
=
(
batch_size
,
n
)).
long
()
out
=
sm
(
s
,
z
,
f
)
out
=
sm
(
{
"single"
:
s
,
"pair"
:
z
}
,
f
)
if
consts
.
is_multimer
:
if
consts
.
is_multimer
:
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
4
,
4
))
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
4
,
4
))
...
@@ -183,10 +183,13 @@ class TestStructureModule(unittest.TestCase):
...
@@ -183,10 +183,13 @@ class TestStructureModule(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
structure_module
(
out_repro
=
model
.
structure_module
(
torch
.
as_tensor
(
representations
[
"single"
]).
cuda
(),
{
torch
.
as_tensor
(
representations
[
"pair"
]).
cuda
(),
"single"
:
torch
.
as_tensor
(
representations
[
"single"
]).
cuda
(),
"pair"
:
torch
.
as_tensor
(
representations
[
"pair"
]).
cuda
(),
},
torch
.
as_tensor
(
batch
[
"aatype"
]).
cuda
(),
torch
.
as_tensor
(
batch
[
"aatype"
]).
cuda
(),
mask
=
torch
.
as_tensor
(
batch
[
"seq_mask"
]).
cuda
(),
mask
=
torch
.
as_tensor
(
batch
[
"seq_mask"
]).
cuda
(),
inplace_safe
=
False
,
)
)
out_repro
=
out_repro
[
"positions"
][
-
1
].
cpu
()
out_repro
=
out_repro
[
"positions"
][
-
1
].
cpu
()
...
@@ -286,7 +289,7 @@ class TestInvariantPointAttention(unittest.TestCase):
...
@@ -286,7 +289,7 @@ class TestInvariantPointAttention(unittest.TestCase):
if
consts
.
is_multimer
:
if
consts
.
is_multimer
:
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
()
torch
.
as_tensor
(
affines
).
float
()
.
cuda
()
)
)
sample_affine
=
rigids
sample_affine
=
rigids
else
:
else
:
...
...
tests/test_template.py
View file @
39a6d0e6
...
@@ -206,7 +206,7 @@ class Template(unittest.TestCase):
...
@@ -206,7 +206,7 @@ class Template(unittest.TestCase):
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_compare
(
self
):
def
test_compare
(
self
):
def
test_template_embedding
(
pair
,
batch
,
mask_2d
):
def
test_template_embedding
(
pair
,
batch
,
mask_2d
,
mc_
mask_2d
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
te
=
self
.
am_modules
.
TemplateEmbedding
(
te
=
self
.
am_modules
.
TemplateEmbedding
(
config
.
model
.
embeddings_and_evoformer
.
template
,
config
.
model
.
embeddings_and_evoformer
.
template
,
...
@@ -214,7 +214,7 @@ class Template(unittest.TestCase):
...
@@ -214,7 +214,7 @@ class Template(unittest.TestCase):
)
)
if
consts
.
is_multimer
:
if
consts
.
is_multimer
:
act
=
te
(
pair
,
batch
,
mask_2d
,
multichain_mask_2d
=
m
ultichain
_mask_2d
,
is_training
=
False
)
act
=
te
(
pair
,
batch
,
mask_2d
,
multichain_mask_2d
=
m
c
_mask_2d
,
is_training
=
False
)
else
:
else
:
act
=
te
(
pair
,
batch
,
mask_2d
,
is_training
=
False
)
act
=
te
(
pair
,
batch
,
mask_2d
,
is_training
=
False
)
return
act
return
act
...
@@ -228,12 +228,12 @@ class Template(unittest.TestCase):
...
@@ -228,12 +228,12 @@ class Template(unittest.TestCase):
batch
=
random_template_feats
(
n_templ
,
n_res
)
batch
=
random_template_feats
(
n_templ
,
n_res
)
batch
[
"template_all_atom_masks"
]
=
batch
[
"template_all_atom_mask"
]
batch
[
"template_all_atom_masks"
]
=
batch
[
"template_all_atom_mask"
]
multichain_mask_2d
=
None
if
consts
.
is_multimer
:
if
consts
.
is_multimer
:
asym_id
=
batch
[
'asym_id'
][
0
]
asym_id
=
batch
[
'asym_id'
][
0
]
multichain_mask_2d
=
(
multichain_mask_2d
=
(
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]
).
astype
(
np
.
float32
)
).
astype
(
np
.
float32
)
batch
[
"multichain_mask_2d"
]
=
multichain_mask_2d
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
# Fetch pretrained parameters (but only from one block)]
# Fetch pretrained parameters (but only from one block)]
...
@@ -242,7 +242,7 @@ class Template(unittest.TestCase):
...
@@ -242,7 +242,7 @@ class Template(unittest.TestCase):
)
)
out_gt
=
f
.
apply
(
out_gt
=
f
.
apply
(
params
,
jax
.
random
.
PRNGKey
(
42
),
pair_act
,
batch
,
pair_mask
params
,
jax
.
random
.
PRNGKey
(
42
),
pair_act
,
batch
,
pair_mask
,
multichain_mask_2d
).
block_until_ready
()
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
...
@@ -259,7 +259,9 @@ class Template(unittest.TestCase):
...
@@ -259,7 +259,9 @@ class Template(unittest.TestCase):
torch
.
as_tensor
(
pair_mask
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
chunk_size
=
consts
.
chunk_size
,
multichain_mask_2d
=
multichain_mask_2d
,
multichain_mask_2d
=
torch
.
as_tensor
(
multichain_mask_2d
).
cuda
(),
use_lma
=
False
,
inplace_safe
=
False
)
)
else
:
else
:
out_repro
=
model
.
template_embedder
(
out_repro
=
model
.
template_embedder
(
...
@@ -267,7 +269,9 @@ class Template(unittest.TestCase):
...
@@ -267,7 +269,9 @@ class Template(unittest.TestCase):
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
chunk_size
=
consts
.
chunk_size
,
use_lma
=
False
,
inplace_safe
=
False
)
)
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro
[
"template_pair_embedding"
]
...
...
tests/test_triangular_attention.py
View file @
39a6d0e6
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
copy
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
...
@@ -89,13 +90,19 @@ class TestTriangularAttention(unittest.TestCase):
...
@@ -89,13 +90,19 @@ class TestTriangularAttention(unittest.TestCase):
if
starting
if
starting
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_att_end
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_att_end
)
)
# To save memory, the full model transposes inputs outside of the
# triangle attention module. We adjust the module here.
module
=
copy
.
deepcopy
(
module
)
module
.
starting
=
starting
out_repro
=
module
(
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
chunk_size
=
None
,
chunk_size
=
None
,
).
cpu
()
).
cpu
()
self
.
assertTrue
(
torch
.
m
ax
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
m
ean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_tri_att_end_compare
(
self
):
def
test_tri_att_end_compare
(
self
):
...
...
tests/test_triangular_multiplicative_update.py
View file @
39a6d0e6
...
@@ -30,12 +30,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -30,12 +30,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def
test_shape
(
self
):
def
test_shape
(
self
):
c_z
=
consts
.
c_z
c_z
=
consts
.
c_z
c
=
11
c
=
11
outgoing
=
True
tm
=
TriangleMultiplicationOutgoing
(
tm
=
TriangleMultiplicationOutgoing
(
c_z
,
c_z
,
c
,
c
,
outgoing
,
)
)
n_res
=
consts
.
c_z
n_res
=
consts
.
c_z
...
@@ -94,9 +92,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -94,9 +92,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
out_repro
=
module
(
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
inplace_safe
=
True
,
_inplace_chunk_size
=
4
,
).
cpu
()
).
cpu
()
self
.
assertTrue
(
torch
.
m
ax
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
)
)
self
.
assertTrue
(
torch
.
m
ean
(
torch
.
abs
(
out_gt
-
out_repro
)
)
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_tri_mul_out_compare
(
self
):
def
test_tri_mul_out_compare
(
self
):
...
@@ -106,6 +105,39 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -106,6 +105,39 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def
test_tri_mul_in_compare
(
self
):
def
test_tri_mul_in_compare
(
self
):
self
.
_tri_mul_compare
(
incoming
=
True
)
self
.
_tri_mul_compare
(
incoming
=
True
)
def
_tri_mul_inplace
(
self
,
incoming
=
False
):
n_res
=
consts
.
n_res
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
pair_mask
=
np
.
random
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
n_res
,
n_res
))
pair_mask
=
pair_mask
.
astype
(
np
.
float32
)
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_in
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_out
)
out_stock
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
inplace_safe
=
False
,
).
cpu
()
# This has to come second because inference mode is in-place
out_inplace
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
inplace_safe
=
True
,
_inplace_chunk_size
=
2
,
).
cpu
()
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_stock
-
out_inplace
))
<
consts
.
eps
)
def
test_tri_mul_out_inference
(
self
):
self
.
_tri_mul_inplace
()
def
test_tri_mul_in_inference
(
self
):
self
.
_tri_mul_inplace
(
incoming
=
True
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
tests/test_utils.py
View file @
39a6d0e6
...
@@ -23,7 +23,7 @@ from openfold.utils.rigid_utils import (
...
@@ -23,7 +23,7 @@ from openfold.utils.rigid_utils import (
quat_to_rot
,
quat_to_rot
,
rot_to_quat
,
rot_to_quat
,
)
)
from
openfold.utils.
tensor
_utils
import
chunk_layer
,
_chunk_slice
from
openfold.utils.
chunk
_utils
import
chunk_layer
,
_chunk_slice
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.config
import
consts
...
...
thread_sequence.py
0 → 100644
View file @
39a6d0e6
import argparse
import logging
import os
import random

import numpy
import torch

from openfold.config import model_config
from openfold.data import feature_pipeline
from openfold.data.data_pipeline import make_sequence_features_with_custom_template
from openfold.np import protein
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, \
    run_model, prep_output, relax_protein
from openfold.utils.tensor_utils import (
    tensor_tree_map,
)
from scripts.utils import add_data_args

logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

# TF32 matmuls are available from torch 1.12 onward.
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if (
    torch_major_version > 1
    or (torch_major_version == 1 and torch_minor_version >= 12)
):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

# Inference-only script: gradients are never needed.
torch.set_grad_enabled(False)
def main(args):
    """Thread a single query sequence onto a custom template structure and
    run OpenFold inference on the result.

    Reads one sequence from ``args.input_fasta``, builds template features
    from ``args.input_mmcif``, runs every model yielded by
    ``load_models_from_command_line``, writes an unrelaxed PDB per model,
    and then relaxes it.

    Raises:
        ValueError: if the input FASTA contains more than one sequence.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    config = model_config(args.config_preset)

    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2 ** 32)
    # Coerce defensively: numpy.random.seed rejects non-integer seeds.
    random_seed = int(random_seed)

    numpy.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    with open(args.input_fasta) as fasta_file:
        tags, sequences = parse_fasta(fasta_file.read())

    if len(sequences) != 1:
        raise ValueError("the threading script can only process a single sequence")

    query_sequence = sequences[0]
    query_tag = tags[0]

    feature_dict = make_sequence_features_with_custom_template(
        query_sequence,
        args.input_mmcif,
        args.template_id,
        args.chain_id,
        args.kalign_binary_path,
    )

    processed_feature_dict = feature_processor.process_features(
        feature_dict, mode='predict',
    )

    processed_feature_dict = {
        k: torch.as_tensor(v, device=args.model_device)
        for k, v in processed_feature_dict.items()
    }

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir,
    )

    output_name = f'{query_tag}_{args.config_preset}'
    for model, output_directory in model_generator:
        out = run_model(model, processed_feature_dict, query_tag, args.output_dir)

        # Toss out the recycling dimensions --- we don't need them anymore.
        # NOTE: write the stripped numpy copy to a separate name instead of
        # clobbering processed_feature_dict, so that later models yielded by
        # model_generator still receive the original torch tensors.
        feature_dict_np = tensor_tree_map(
            lambda x: numpy.array(x[..., -1].cpu()), processed_feature_dict
        )
        out = tensor_tree_map(lambda x: numpy.array(x.cpu()), out)

        unrelaxed_protein = prep_output(
            out,
            feature_dict_np,
            feature_dict,
            feature_processor,
            args.config_preset,
            200,  # the ri_multimer_gap; no multimer sequences here, so the value is irrelevant
            args.subtract_plddt
        )

        unrelaxed_output_path = os.path.join(
            output_directory,
            f'{output_name}_unrelaxed.pdb'
        )

        with open(unrelaxed_output_path, 'w') as fp:
            fp.write(protein.to_pdb(unrelaxed_protein))

        logger.info(f"Output written to {unrelaxed_output_path}...")

        logger.info(f"Running relaxation on {unrelaxed_output_path}...")
        relax_protein(
            config, args.model_device, unrelaxed_protein,
            output_directory, output_name
        )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_fasta", type=str,
        help="the path to a fasta file containing a single sequence to thread"
    )
    parser.add_argument(
        "input_mmcif", type=str,
        help="the path to an mmcif file to thread the sequence on to"
    )
    parser.add_argument(
        "--template_id", type=str,
        help="a PDB id or other identifier for the template"
    )
    parser.add_argument(
        "--chain_id", type=str,
        help="""The chain ID of the chain in the template to use"""
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help="""Whether to output (100 - pLDDT) in the B-factor column instead
             of the pLDDT itself"""
    )
    # type=int (not str): the value is passed to numpy.random.seed, which
    # rejects string seeds with a TypeError.
    parser.add_argument(
        "--data_random_seed", type=int, default=None
    )
    add_data_args(parser)
    args = parser.parse_args()

    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logging.warning(
            """The model is being run on CPU. Consider specifying
            --model_device for better performance"""
        )

    main(args)
\ No newline at end of file
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment