ModelZoo / alphafold2_jax — Commits

Commit 9b18d6a9, authored Dec 11, 2022 by Augustin Zidek

Release code for v2.3.0

PiperOrigin-RevId: 494507694
parent 4494af84
Changes (30 files in the commit; 20 shown on this page): 667 additions and 430 deletions (+667 −430).
README.md                                      +130  −112
afdb/README.md                                   +1    −1
alphafold/data/pipeline.py                       +7    −7
alphafold/model/common_modules.py               +61    −0
alphafold/model/config.py                       +64   −24
alphafold/model/folding.py                       +4    −4
alphafold/model/folding_multimer.py              +4    −4
alphafold/model/geometry/struct_of_array.py      +2    −2
alphafold/model/mapping.py                       +3    −3
alphafold/model/modules.py                     +109   −23
alphafold/model/modules_multimer.py            +224  −172
alphafold/model/utils.py                        +22    −0
alphafold/notebooks/notebook_utils.py           +10   −35
alphafold/notebooks/notebook_utils_test.py      +13   −19
alphafold/relax/relax.py                         +4    −4
alphafold/relax/relax_test.py                    +1    −1
alphafold/relax/utils.py                         +0   −11
docker/Dockerfile                                +3    −3
docker/run_docker.py                             +5    −5
docs/casp15_predictions.zip                      +0    −0
README.md

(Diff collapsed in the page view; +130 −112, not shown.)
afdb/README.md

@@ -304,9 +304,9 @@ fractionPlddtVeryHigh | `FLOAT64` | Fraction of the residues in the predi…
 fractionPlddtVeryLow  | `FLOAT64`       | Fraction of the residues in the prediction with pLDDT less than 50
 gene                  | `STRING`        | The name of the gene if known, e.g. "COII"
 geneSynonyms          | `ARRAY<STRING>` | Additional synonyms for the gene
+globalMetricValue     | `FLOAT64`       | The mean pLDDT of this prediction
 isReferenceProteome   | `BOOL`          | Is this protein part of the reference proteome?
 isReviewed            | `BOOL`          | Has this protein been reviewed, i.e. is it part of SwissProt?
-globalMetricValue     | `FLOAT64`       | The mean pLDDT of this prediction
 latestVersion         | `INT64`         | The latest AFDB version for this prediction
 modelCreatedDate      | `DATE`          | The date of creation for this entry, e.g. "2022-06-01"
 organismCommonNames   | `ARRAY<STRING>` | List of common organism names
alphafold/data/pipeline.py

@@ -117,7 +117,7 @@ class DataPipeline:
                uniref90_database_path: str,
                mgnify_database_path: str,
                bfd_database_path: Optional[str],
-               uniclust30_database_path: Optional[str],
+               uniref30_database_path: Optional[str],
                small_bfd_database_path: Optional[str],
                template_searcher: TemplateSearcher,
                template_featurizer: templates.TemplateHitFeaturizer,

@@ -135,9 +135,9 @@ class DataPipeline:
           binary_path=jackhmmer_binary_path,
           database_path=small_bfd_database_path)
     else:
-      self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
+      self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
           binary_path=hhblits_binary_path,
-          databases=[bfd_database_path, uniclust30_database_path])
+          databases=[bfd_database_path, uniref30_database_path])
     self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
         binary_path=jackhmmer_binary_path,
         database_path=mgnify_database_path)

@@ -211,14 +211,14 @@ class DataPipeline:
           use_precomputed_msas=self.use_precomputed_msas)
       bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
     else:
-      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
-      hhblits_bfd_uniclust_result = run_msa_tool(
-          msa_runner=self.hhblits_bfd_uniclust_runner,
+      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
+      hhblits_bfd_uniref_result = run_msa_tool(
+          msa_runner=self.hhblits_bfd_uniref_runner,
           input_fasta_path=input_fasta_path,
           msa_out_path=bfd_out_path,
           msa_format='a3m',
           use_precomputed_msas=self.use_precomputed_msas)
-      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
+      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])
     templates_result = self.template_featurizer.get_templates(
         query_sequence=input_sequence,
alphafold/model/common_modules.py

@@ -128,3 +128,64 @@ class Linear(hk.Module):
 
     return output
 
+
+class LayerNorm(hk.LayerNorm):
+  """LayerNorm module.
+
+  Equivalent to hk.LayerNorm but with different parameter shapes: they are
+  always vectors rather than possibly higher-rank tensors. This makes it easier
+  to change the layout whilst keep the model weight-compatible.
+  """
+
+  def __init__(self,
+               axis,
+               create_scale: bool,
+               create_offset: bool,
+               eps: float = 1e-5,
+               scale_init=None,
+               offset_init=None,
+               use_fast_variance: bool = False,
+               name=None,
+               param_axis=None):
+    super().__init__(
+        axis=axis,
+        create_scale=False,
+        create_offset=False,
+        eps=eps,
+        scale_init=None,
+        offset_init=None,
+        use_fast_variance=use_fast_variance,
+        name=name,
+        param_axis=param_axis)
+    self._temp_create_scale = create_scale
+    self._temp_create_offset = create_offset
+
+  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+    is_bf16 = (x.dtype == jnp.bfloat16)
+    if is_bf16:
+      x = x.astype(jnp.float32)
+
+    param_axis = self.param_axis[0] if self.param_axis else -1
+    param_shape = (x.shape[param_axis],)
+
+    param_broadcast_shape = [1] * x.ndim
+    param_broadcast_shape[param_axis] = x.shape[param_axis]
+    scale = None
+    offset = None
+    if self._temp_create_scale:
+      scale = hk.get_parameter(
+          'scale', param_shape, x.dtype, init=self.scale_init)
+      scale = scale.reshape(param_broadcast_shape)
+
+    if self._temp_create_offset:
+      offset = hk.get_parameter(
+          'offset', param_shape, x.dtype, init=self.offset_init)
+      offset = offset.reshape(param_broadcast_shape)
+
+    out = super().__call__(x, scale=scale, offset=offset)
+
+    if is_bf16:
+      out = out.astype(jnp.bfloat16)
+
+    return out
\ No newline at end of file
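The new common_modules.LayerNorm keeps its scale/offset parameters as plain 1-D vectors and normalises bfloat16 inputs in float32, which is what keeps existing checkpoints weight-compatible. A minimal usage sketch, not part of the diff; it assumes dm-haiku and jax are installed and the alphafold package at this revision is importable, and the shapes are illustrative only:

import haiku as hk
import jax
import jax.numpy as jnp

from alphafold.model import common_modules


def forward(x):
  # Same call pattern the diff uses wherever hk.LayerNorm was replaced.
  ln = common_modules.LayerNorm(
      axis=[-1], create_scale=True, create_offset=True, name='example_norm')
  return ln(x)


f = hk.transform(forward)
x = jnp.ones((4, 8), dtype=jnp.bfloat16)
params = f.init(jax.random.PRNGKey(0), x)
out = f.apply(params, None, x)
# 'scale'/'offset' are stored as float32 vectors; the output is cast back
# to bfloat16 because the input was bfloat16.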
alphafold/model/config.py

@@ -26,12 +26,12 @@ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
 
 def model_config(name: str) -> ml_collections.ConfigDict:
   """Get the ConfigDict of a CASP14 model."""
-  if 'multimer' in name:
-    return CONFIG_MULTIMER
   if name not in CONFIG_DIFFS:
     raise ValueError(f'Invalid model name {name}.')
-  cfg = copy.deepcopy(CONFIG)
+  if 'multimer' in name:
+    cfg = copy.deepcopy(CONFIG_MULTIMER)
+  else:
+    cfg = copy.deepcopy(CONFIG)
   cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
   return cfg

@@ -52,11 +52,11 @@ MODEL_PRESETS = {
         'model_5_ptm',
     ),
     'multimer': (
-        'model_1_multimer_v2',
-        'model_2_multimer_v2',
-        'model_3_multimer_v2',
-        'model_4_multimer_v2',
-        'model_5_multimer_v2',
+        'model_1_multimer_v3',
+        'model_2_multimer_v3',
+        'model_3_multimer_v3',
+        'model_4_multimer_v3',
+        'model_5_multimer_v3',
     ),
 }
 MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']

@@ -118,8 +118,32 @@ CONFIG_DIFFS = {
     },
     'model_5_ptm': {
         'model.heads.predicted_aligned_error.weight': 0.1
-    }
+    },
+    'model_1_multimer_v3': {},
+    'model_2_multimer_v3': {},
+    'model_3_multimer_v3': {},
+    'model_4_multimer_v3': {
+        'model.embeddings_and_evoformer.num_extra_msa': 1152
+    },
+    'model_5_multimer_v3': {
+        'model.embeddings_and_evoformer.num_extra_msa': 1152
+    },
 }
+
+# Key differences between multimer v1/v2 and v3, mostly due to numerical
+# optimisations in the TriangleMultiplication module.
+common_updates = {
+    'model.embeddings_and_evoformer.num_msa': 252,
+    'model.embeddings_and_evoformer.num_extra_msa': 1152,
+    'model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_incoming.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_outgoing.fuse_projection_weights': False,
+}
+CONFIG_DIFFS.update(
+    {f'model_{i}_multimer': common_updates for i in range(1, 6)})
+CONFIG_DIFFS.update(
+    {f'model_{i}_multimer_v2': common_updates for i in range(1, 6)})
 
 CONFIG = ml_collections.ConfigDict({
     'data': {

@@ -260,14 +284,16 @@ CONFIG = ml_collections.ConfigDict({
                 'equation': 'ikc,jkc->ijc',
                 'num_intermediate_channel': 128,
                 'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
+                'fuse_projection_weights': False,
             },
             'triangle_multiplication_incoming': {
                 'dropout_rate': 0.25,
                 'equation': 'kjc,kic->ijc',
                 'num_intermediate_channel': 128,
                 'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
+                'fuse_projection_weights': False,
             },
             'pair_transition': {
                 'dropout_rate': 0.0,

@@ -328,14 +354,16 @@ CONFIG = ml_collections.ConfigDict({
                 'equation': 'ikc,jkc->ijc',
                 'num_intermediate_channel': 64,
                 'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
+                'fuse_projection_weights': False,
            },
            'triangle_multiplication_incoming': {
                'dropout_rate': 0.25,
                'equation': 'kjc,kic->ijc',
                'num_intermediate_channel': 64,
                'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
+                'fuse_projection_weights': False,
            },
            'pair_transition': {
                'dropout_rate': 0.0,

@@ -354,7 +382,7 @@ CONFIG = ml_collections.ConfigDict({
         'multimer_mode': False,
         'subbatch_size': 4,
         'use_remat': False,
-        'zero_init': True
+        'zero_init': True,
     },
     'heads': {
         'distogram': {

@@ -483,27 +511,29 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
                 'gating': True,
                 'num_head': 4,
                 'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
             },
             'triangle_multiplication_incoming': {
                 'dropout_rate': 0.25,
                 'equation': 'kjc,kic->ijc',
                 'num_intermediate_channel': 128,
                 'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
+                'fuse_projection_weights': True,
             },
             'triangle_multiplication_outgoing': {
                 'dropout_rate': 0.25,
                 'equation': 'ikc,jkc->ijc',
                 'num_intermediate_channel': 128,
                 'orientation': 'per_row',
-                'shared_dropout': True
+                'shared_dropout': True,
+                'fuse_projection_weights': True,
             }
         },
         'extra_msa_channel': 64,
         'extra_msa_stack_num_block': 4,
-        'num_msa': 252,
-        'num_extra_msa': 1152,
+        'num_msa': 508,
+        'num_extra_msa': 2048,
         'masked_msa': {
             'profile_prob': 0.1,
             'replace_fraction': 0.15,

@@ -564,24 +594,28 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
                     'equation': 'kjc,kic->ijc',
                     'num_intermediate_channel': 64,
                     'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
+                    'fuse_projection_weights': True,
                 },
                 'triangle_multiplication_outgoing': {
                     'dropout_rate': 0.25,
                     'equation': 'ikc,jkc->ijc',
                     'num_intermediate_channel': 64,
                     'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
+                    'fuse_projection_weights': True,
                 }
             }
         },
     },
     'global_config': {
         'bfloat16': True,
         'bfloat16_output': False,
         'deterministic': False,
         'multimer_mode': True,
         'subbatch_size': 4,
         'use_remat': False,
-        'zero_init': True
+        'zero_init': True,
     },
     'heads': {
         'distogram': {

@@ -651,7 +685,13 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
             }
         },
         'num_ensemble_eval': 1,
-        'num_recycle': 3,
+        'num_recycle': 20,
+        # A negative value indicates that no early stopping will occur, i.e.
+        # the model will always run `num_recycle` number of recycling
+        # iterations. A positive value will enable early stopping if the
+        # difference in pairwise distances is less than the tolerance between
+        # recycling steps.
+        'recycle_early_stop_tolerance': 0.5,
         'resample_msa_in_recycling': True
     }
 })
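A short sketch of how the new presets are consumed (illustrative; the printed values are the ones set in this diff and assume the alphafold package at this revision is importable):

from alphafold.model import config

cfg = config.model_config('model_1_multimer_v3')
print(cfg.model.num_recycle)                    # 20 (was 3), with early stopping
print(cfg.model.recycle_early_stop_tolerance)   # 0.5
print(cfg.model.embeddings_and_evoformer.evoformer
      .triangle_multiplication_incoming.fuse_projection_weights)  # True for v3

The v1/v2 presets keep the old behaviour by overriding num_msa/num_extra_msa and switching fuse_projection_weights back to False via CONFIG_DIFFS.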
alphafold/model/folding.py

@@ -331,7 +331,7 @@ class FoldIteration(hk.Module):
     safe_key, *sub_keys = safe_key.split(3)
     sub_keys = iter(sub_keys)
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,

@@ -353,7 +353,7 @@ class FoldIteration(hk.Module):
       act = jax.nn.relu(act)
     act += input_act
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,

@@ -410,7 +410,7 @@ def generate_affines(representations, batch, config, global_config,
   c = config
   sequence_mask = batch['seq_mask'][:, None]
 
-  act = hk.LayerNorm(
+  act = common_modules.LayerNorm(
       axis=[-1],
       create_scale=True,
       create_offset=True,

@@ -433,7 +433,7 @@ def generate_affines(representations, batch, config, global_config,
       'affine': affine.to_tensor(),
   }
 
-  act_2d = hk.LayerNorm(
+  act_2d = common_modules.LayerNorm(
      axis=[-1],
      create_scale=True,
      create_offset=True,
alphafold/model/folding_multimer.py

@@ -427,7 +427,7 @@ class FoldIteration(hk.Module):
     safe_key, *sub_keys = safe_key.split(3)
     sub_keys = iter(sub_keys)
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=-1,
         create_scale=True,
         create_offset=True,

@@ -448,7 +448,7 @@ class FoldIteration(hk.Module):
     act = jax.nn.relu(act)
     act += input_act
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=-1,
         create_scale=True,
         create_offset=True,

@@ -500,7 +500,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
   """
   c = config
   sequence_mask = batch['seq_mask'][:, None]
-  act = hk.LayerNorm(
+  act = common_modules.LayerNorm(
       axis=-1,
       create_scale=True,
       create_offset=True,
       name='single_layer_norm')(
           representations['single'])

@@ -523,7 +523,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
           rigid
   }
 
-  act_2d = hk.LayerNorm(
+  act_2d = common_modules.LayerNorm(
      axis=-1,
      create_scale=True,
      create_offset=True,
alphafold/model/geometry/struct_of_array.py

@@ -133,7 +133,7 @@ def flatten(instance):
   inner_treedefs = []
   num_arrays = []
   for array_like in array_likes:
-    flat_array_like, inner_treedef = jax.tree_flatten(array_like)
+    flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like)
     inner_treedefs.append(inner_treedef)
     flat_array_likes += flat_array_like
     num_arrays.append(len(flat_array_like))

@@ -206,7 +206,7 @@ class StructOfArray:
       for num_array, inner_treedef, array_field in zip(
           num_arrays, inner_treedefs, array_fields):
-        value_dict[array_field] = jax.tree_unflatten(
+        value_dict[array_field] = jax.tree_util.tree_unflatten(
             inner_treedef, data[array_start:array_start + num_array])
         array_start += num_array
       metadata_fields = get_metadata_fields(new_cls)
alphafold/model/mapping.py

@@ -47,11 +47,11 @@ def _maybe_get_size(array, axis):
 
 def _expand_axes(axes, values, name='sharded_apply'):
-  values_tree_def = jax.tree_flatten(values)[1]
+  values_tree_def = jax.tree_util.tree_flatten(values)[1]
   flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
   # Replace None's with PROXY
   flat_axes = [PROXY if x is None else x for x in flat_axes]
-  return jax.tree_unflatten(values_tree_def, flat_axes)
+  return jax.tree_util.tree_unflatten(values_tree_def, flat_axes)
 
 
 def sharded_map(

@@ -126,7 +126,7 @@ def sharded_apply(
     in_axes_ = _expand_axes(in_axes, args)
 
     in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
-    flat_sizes = jax.tree_flatten(in_sizes)[0]
+    flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0]
     in_size = max(flat_sizes)
     assert all(i in {in_size, -1} for i in flat_sizes)
alphafold/model/modules.py

@@ -501,7 +501,7 @@ class Transition(hk.Module):
     num_intermediate = int(nc * self.config.num_intermediate_factor)
     mask = jnp.expand_dims(mask, axis=-1)
 
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,

@@ -569,12 +569,15 @@ class Attention(hk.Module):
     q_weights = hk.get_parameter(
         'query_w', shape=(q_data.shape[-1], num_head, key_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
     k_weights = hk.get_parameter(
         'key_w', shape=(m_data.shape[-1], num_head, key_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
     v_weights = hk.get_parameter(
         'value_w', shape=(m_data.shape[-1], num_head, value_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
 
     q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)

@@ -595,10 +598,12 @@ class Attention(hk.Module):
       gating_weights = hk.get_parameter(
           'gating_w',
           shape=(q_data.shape[-1], num_head, value_dim),
+          dtype=q_data.dtype,
           init=hk.initializers.Constant(0.0))
       gating_bias = hk.get_parameter(
           'gating_b',
           shape=(num_head, value_dim),
+          dtype=q_data.dtype,
           init=hk.initializers.Constant(1.0))
 
       gate_values = jnp.einsum('bqc, chv->bqhv', q_data,

@@ -610,9 +615,12 @@ class Attention(hk.Module):
     o_weights = hk.get_parameter(
         'output_w', shape=(num_head, value_dim, self.output_dim),
+        dtype=q_data.dtype,
         init=init)
-    o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),
-                              init=hk.initializers.Constant(0.0))
+    o_bias = hk.get_parameter(
+        'output_b', shape=(self.output_dim,),
+        dtype=q_data.dtype,
+        init=hk.initializers.Constant(0.0))
 
     output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias

@@ -658,12 +666,15 @@ class GlobalAttention(hk.Module):
     q_weights = hk.get_parameter(
         'query_w', shape=(q_data.shape[-1], num_head, key_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
     k_weights = hk.get_parameter(
         'key_w', shape=(m_data.shape[-1], key_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
     v_weights = hk.get_parameter(
         'value_w', shape=(m_data.shape[-1], value_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
 
     v = jnp.einsum('bka,ac->bkc', m_data, v_weights)

@@ -684,18 +695,23 @@ class GlobalAttention(hk.Module):
     o_weights = hk.get_parameter(
         'output_w', shape=(num_head, value_dim, self.output_dim),
+        dtype=q_data.dtype,
         init=init)
-    o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),
-                              init=hk.initializers.Constant(0.0))
+    o_bias = hk.get_parameter(
+        'output_b', shape=(self.output_dim,),
+        dtype=q_data.dtype,
+        init=hk.initializers.Constant(0.0))
 
     if self.config.gating:
       gating_weights = hk.get_parameter(
           'gating_w',
           shape=(q_data.shape[-1], num_head, value_dim),
+          dtype=q_data.dtype,
           init=hk.initializers.Constant(0.0))
       gating_bias = hk.get_parameter(
           'gating_b',
           shape=(num_head, value_dim),
+          dtype=q_data.dtype,
           init=hk.initializers.Constant(1.0))
 
       gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights)

@@ -745,11 +761,11 @@ class MSARowAttentionWithPairBias(hk.Module):
     bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4
 
-    msa_act = hk.LayerNorm(
+    msa_act = common_modules.LayerNorm(
         axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
             msa_act)
 
-    pair_act = hk.LayerNorm(
+    pair_act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,

@@ -760,6 +776,7 @@ class MSARowAttentionWithPairBias(hk.Module):
     weights = hk.get_parameter(
         'feat_2d_weights',
         shape=(pair_act.shape[-1], c.num_head),
+        dtype=msa_act.dtype,
         init=hk.initializers.RandomNormal(stddev=init_factor))
     nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)

@@ -812,7 +829,7 @@ class MSAColumnAttention(hk.Module):
     bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4
 
-    msa_act = hk.LayerNorm(
+    msa_act = common_modules.LayerNorm(
         axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
             msa_act)

@@ -867,7 +884,7 @@ class MSAColumnGlobalAttention(hk.Module):
     bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4
 
-    msa_act = hk.LayerNorm(
+    msa_act = common_modules.LayerNorm(
         axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
             msa_act)

@@ -924,7 +941,7 @@ class TriangleAttention(hk.Module):
     bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4
 
-    pair_act = hk.LayerNorm(
+    pair_act = common_modules.LayerNorm(
         axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
             pair_act)

@@ -932,6 +949,7 @@ class TriangleAttention(hk.Module):
     weights = hk.get_parameter(
         'feat_2d_weights',
         shape=(pair_act.shape[-1], c.num_head),
+        dtype=pair_act.dtype,
         init=hk.initializers.RandomNormal(stddev=init_factor))
     nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)

@@ -1029,7 +1047,7 @@ class PredictedLDDTHead(hk.Module):
     """
     act = representations['structure_module']
 
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,

@@ -1251,6 +1269,19 @@ class ExperimentallyResolvedHead(hk.Module):
     return output
 
 
+def _layer_norm(axis=-1, name='layer_norm'):
+  return common_modules.LayerNorm(
+      axis=axis,
+      create_scale=True,
+      create_offset=True,
+      eps=1e-5,
+      use_fast_variance=True,
+      scale_init=hk.initializers.Constant(1.),
+      offset_init=hk.initializers.Constant(0.),
+      param_axis=axis,
+      name=name)
+
+
 class TriangleMultiplication(hk.Module):
   """Triangle multiplication layer ("outgoing" or "incoming").

@@ -1263,25 +1294,34 @@ class TriangleMultiplication(hk.Module):
     self.config = config
     self.global_config = global_config
 
-  def __call__(self, act, mask, is_training=True):
+  def __call__(self, left_act, left_mask, is_training=True):
     """Builds TriangleMultiplication module.
 
     Arguments:
-      act: Pair activations, shape [N_res, N_res, c_z]
-      mask: Pair mask, shape [N_res, N_res].
+      left_act: Pair activations, shape [N_res, N_res, c_z]
+      left_mask: Pair mask, shape [N_res, N_res].
       is_training: Whether the module is in training mode.
 
     Returns:
-      Outputs, same shape/type as act.
+      Outputs, same shape/type as left_act.
     """
     del is_training
 
+    if self.config.fuse_projection_weights:
+      return self._fused_triangle_multiplication(left_act, left_mask)
+    else:
+      return self._triangle_multiplication(left_act, left_mask)
+
+  @hk.transparent
+  def _triangle_multiplication(self, left_act, left_mask):
+    """Implementation of TriangleMultiplication used in AF2 and AF-M<2.3."""
     c = self.config
     gc = self.global_config
 
-    mask = mask[..., None]
+    mask = left_mask[..., None]
 
-    act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
-                       name='layer_norm_input')(act)
+    act = common_modules.LayerNorm(axis=[-1], create_scale=True,
+                                   create_offset=True,
+                                   name='layer_norm_input')(left_act)
     input_act = act
 
     left_projection = common_modules.Linear(

@@ -1317,7 +1357,7 @@ class TriangleMultiplication(hk.Module):
     # b = left_proj_act and a = right_proj_act
     act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
 
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,

@@ -1340,6 +1380,50 @@ class TriangleMultiplication(hk.Module):
 
     return act
 
+  @hk.transparent
+  def _fused_triangle_multiplication(self, left_act, left_mask):
+    """TriangleMultiplication with fused projection weights."""
+    mask = left_mask[..., None]
+    c = self.config
+    gc = self.global_config
+
+    left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)
+
+    # Both left and right projections are fused into projection.
+    projection = common_modules.Linear(
+        2 * c.num_intermediate_channel, name='projection')
+    proj_act = mask * projection(left_act)
+
+    # Both left + right gate are fused into gate_values.
+    gate_values = common_modules.Linear(
+        2 * c.num_intermediate_channel,
+        name='gate',
+        bias_init=1.,
+        initializer=utils.final_init(gc))(left_act)
+    proj_act *= jax.nn.sigmoid(gate_values)
+
+    left_proj_act = proj_act[:, :, :c.num_intermediate_channel]
+    right_proj_act = proj_act[:, :, c.num_intermediate_channel:]
+    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
+
+    act = _layer_norm(axis=-1, name='center_norm')(act)
+
+    output_channel = int(left_act.shape[-1])
+    act = common_modules.Linear(
+        output_channel,
+        initializer=utils.final_init(gc),
+        name='output_projection')(act)
+
+    gate_values = common_modules.Linear(
+        output_channel,
+        bias_init=1.,
+        initializer=utils.final_init(gc),
+        name='gating_linear')(left_act)
+    act *= jax.nn.sigmoid(gate_values)
+
+    return act
+
 
 class DistogramHead(hk.Module):
   """Head to predict a distogram.

@@ -1446,7 +1530,7 @@ class OuterProductMean(hk.Module):
     c = self.config
 
     mask = mask[..., None]
-    act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act)
+    act = common_modules.LayerNorm([-1], True, True, name='layer_norm_input')(act)
 
     left_act = mask * common_modules.Linear(
         c.num_outer_channel,

@@ -1469,9 +1553,11 @@ class OuterProductMean(hk.Module):
         'output_w',
         shape=(c.num_outer_channel, c.num_outer_channel,
               self.num_output_channel),
+        dtype=act.dtype,
         init=init_w)
     output_b = hk.get_parameter(
         'output_b', shape=(self.num_output_channel,),
+        dtype=act.dtype,
         init=hk.initializers.Constant(0.0))
 
     def compute_chunk(left_act):

@@ -1738,7 +1824,7 @@ class EmbeddingsAndEvoformer(hk.Module):
               dgram)
 
     if c.recycle_features:
-      prev_msa_first_row = hk.LayerNorm(
+      prev_msa_first_row = common_modules.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,

@@ -1746,7 +1832,7 @@ class EmbeddingsAndEvoformer(hk.Module):
              batch['prev_msa_first_row'])
       msa_activations = msa_activations.at[0].add(prev_msa_first_row)
 
-      pair_activations += hk.LayerNorm(
+      pair_activations += common_modules.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,

@@ -2020,7 +2106,7 @@ class SingleTemplateEmbedding(hk.Module):
         self.config.template_pair_stack, self.global_config)(
             act, mask_2d, is_training)
 
-    act = hk.LayerNorm([-1], True, True, name='output_layer_norm')(act)
+    act = common_modules.LayerNorm([-1], True, True, name='output_layer_norm')(act)
     return act
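The fuse_projection_weights path in TriangleMultiplication replaces the separate left/right projections (and gates) with a single Linear of twice the channel count that is then split. A standalone sketch of why that is numerically equivalent, using toy shapes and plain jax.numpy rather than Haiku; all names here are illustrative, not from the diff:

import jax.numpy as jnp
import numpy as np

n_res, c_z, c_mid = 5, 7, 3
rng = np.random.default_rng(0)
act = jnp.asarray(rng.normal(size=(n_res, n_res, c_z)), dtype=jnp.float32)
w_left = jnp.asarray(rng.normal(size=(c_z, c_mid)), dtype=jnp.float32)
w_right = jnp.asarray(rng.normal(size=(c_z, c_mid)), dtype=jnp.float32)

# Unfused: two projections, as in _triangle_multiplication.
left = act @ w_left
right = act @ w_right

# Fused: one projection with concatenated weights, then a slice,
# as in _fused_triangle_multiplication.
proj = act @ jnp.concatenate([w_left, w_right], axis=-1)
np.testing.assert_allclose(proj[..., :c_mid], left, rtol=1e-5)
np.testing.assert_allclose(proj[..., c_mid:], right, rtol=1e-5)

One matmul over the concatenated weights does the same work with fewer kernel launches, which is the "numerical optimisation" the config comment refers to.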
alphafold/model/modules_multimer.py

@@ -475,20 +475,51 @@ class AlphaFold(hk.Module):
         # Eval mode or tests: use the maximum number of iterations.
         num_iter = c.num_recycle
 
-      def recycle_body(i, x):
-        del i
-        prev, safe_key = x
+      def distances(points):
+        """Compute all pairwise distances for a set of points."""
+        return jnp.sqrt(jnp.sum((points[:, None] - points[None, :])**2,
+                                axis=-1))
+
+      def recycle_body(x):
+        i, _, prev, safe_key = x
         safe_key1, safe_key2 = safe_key.split() if c.resample_msa_in_recycling else safe_key.duplicate()  # pylint: disable=line-too-long
         ret = apply_network(prev=prev, safe_key=safe_key2)
-        return get_prev(ret), safe_key1
-
-      prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key))
+        return i+1, prev, get_prev(ret), safe_key1
+
+      def recycle_cond(x):
+        i, prev, next_in, _ = x
+        ca_idx = residue_constants.atom_order['CA']
+        sq_diff = jnp.square(distances(prev['prev_pos'][:, ca_idx, :]) -
+                             distances(next_in['prev_pos'][:, ca_idx, :]))
+        mask = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
+        sq_diff = utils.mask_mean(mask, sq_diff)
+        # Early stopping criteria based on criteria used in
+        # AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
+        diff = jnp.sqrt(sq_diff + 1e-8)  # avoid bad numerics giving negatives
+        less_than_max_recycles = (i < num_iter)
+        has_exceeded_tolerance = (
+            (i == 0) | (diff > c.recycle_early_stop_tolerance))
+        return less_than_max_recycles & has_exceeded_tolerance
+
+      if hk.running_init():
+        num_recycles, _, prev, safe_key = recycle_body(
+            (0, prev, prev, safe_key))
+      else:
+        num_recycles, _, prev, safe_key = hk.while_loop(
+            recycle_cond, recycle_body, (0, prev, prev, safe_key))
+    else:
+      # No recycling.
+      num_recycles = 0
 
+    # Run extra iteration.
     ret = apply_network(prev=prev, safe_key=safe_key)
 
     if not return_representations:
       del ret['representations']
+    ret['num_recycles'] = num_recycles
+
     return ret
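The new recycle_cond stops recycling once the mean change in pairwise Cα distances between consecutive iterations falls to the tolerance (0.5 in the new multimer config). A standalone sketch of that criterion with toy coordinates, using plain jax.numpy; unlike the real code it skips the seq_mask weighting, and the arrays are illustrative:

import jax.numpy as jnp


def distances(points):
  """All pairwise distances for an (N, 3) array of points."""
  return jnp.sqrt(jnp.sum((points[:, None] - points[None, :])**2, axis=-1))


def mean_distance_change(prev_ca, next_ca):
  """Mean change of the pairwise distance matrix (no masking)."""
  sq_diff = jnp.square(distances(prev_ca) - distances(next_ca))
  return jnp.sqrt(jnp.mean(sq_diff) + 1e-8)  # avoid bad numerics giving negatives


prev_ca = jnp.zeros((4, 3))
next_ca = prev_ca.at[0, 0].add(0.3)  # move one CA by 0.3 Angstrom
tolerance = 0.5                      # recycle_early_stop_tolerance in the new config
print(mean_distance_change(prev_ca, next_ca) < tolerance)  # True -> stop recycling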
@@ -524,11 +555,13 @@ class EmbeddingsAndEvoformer(hk.Module):
       Feature embedding using the features as described before.
     """
     c = self.config
+    gc = self.global_config
     rel_feats = []
     pos = batch['residue_index']
     asym_id = batch['asym_id']
     asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :])
     offset = pos[:, None] - pos[None, :]
 
+    dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32
+
     clipped_offset = jnp.clip(
         offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx)

@@ -568,6 +601,7 @@ class EmbeddingsAndEvoformer(hk.Module):
     rel_feat = jnp.concatenate(rel_feats, axis=-1)
 
+    rel_feat = rel_feat.astype(dtype)
     return common_modules.Linear(
         c.pair_channel,
         name='position_activations')(

@@ -579,6 +613,7 @@ class EmbeddingsAndEvoformer(hk.Module):
     gc = self.global_config
     batch = dict(batch)
+    dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32
 
     if safe_key is None:
       safe_key = prng.SafeKey(hk.next_rng_key())

@@ -587,177 +622,178 @@ class EmbeddingsAndEvoformer(hk.Module):
 
     batch['msa_profile'] = make_msa_profile(batch)
 
-    target_feat = jax.nn.one_hot(batch['aatype'], 21)
-    preprocess_1d = common_modules.Linear(
-        c.msa_channel, name='preprocess_1d')(target_feat)
-
-    safe_key, sample_key, mask_key = safe_key.split(3)
-    batch = sample_msa(sample_key, batch, c.num_msa)
-    batch = make_masked_msa(batch, mask_key, c.masked_msa)
-
-    (batch['cluster_profile'],
-     batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)
-
-    msa_feat = create_msa_feat(batch)
-
-    preprocess_msa = common_modules.Linear(
-        c.msa_channel, name='preprocess_msa')(msa_feat)
-    msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
-
-    left_single = common_modules.Linear(
-        c.pair_channel, name='left_single')(target_feat)
-    right_single = common_modules.Linear(
-        c.pair_channel, name='right_single')(target_feat)
-    pair_activations = left_single[:, None] + right_single[None]
-    mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
-    mask_2d = mask_2d.astype(jnp.float32)
-
-    if c.recycle_pos:
-      prev_pseudo_beta = modules.pseudo_beta_fn(
-          batch['aatype'], batch['prev_pos'], None)
-
-      dgram = modules.dgram_from_positions(
-          prev_pseudo_beta, **self.config.prev_pos)
-      pair_activations += common_modules.Linear(
-          c.pair_channel, name='prev_pos_linear')(dgram)
-
-    if c.recycle_features:
-      prev_msa_first_row = hk.LayerNorm(
-          axis=[-1], create_scale=True, create_offset=True,
-          name='prev_msa_first_row_norm')(batch['prev_msa_first_row'])
-      msa_activations = msa_activations.at[0].add(prev_msa_first_row)
-
-      pair_activations += hk.LayerNorm(
-          axis=[-1], create_scale=True, create_offset=True,
-          name='prev_pair_norm')(batch['prev_pair'])
-
-    if c.max_relative_idx:
-      pair_activations += self._relative_encoding(batch)
-
-    if c.template.enabled:
-      template_module = TemplateEmbedding(c.template, gc)
-      template_batch = {
-          'template_aatype': batch['template_aatype'],
-          'template_all_atom_positions': batch['template_all_atom_positions'],
-          'template_all_atom_mask': batch['template_all_atom_mask']
-      }
-      # Construct a mask such that only intra-chain template features are
-      # computed, since all templates are for each chain individually.
-      multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
-      safe_key, safe_subkey = safe_key.split()
-      template_act = template_module(
-          query_embedding=pair_activations,
-          template_batch=template_batch,
-          padding_mask_2d=mask_2d,
-          multichain_mask_2d=multichain_mask,
-          is_training=is_training,
-          safe_key=safe_subkey)
-      pair_activations += template_act
-
-    # Extra MSA stack.
-    (extra_msa_feat,
-     extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
-    extra_msa_activations = common_modules.Linear(
-        c.extra_msa_channel, name='extra_msa_activations')(extra_msa_feat)
-    extra_msa_mask = extra_msa_mask.astype(jnp.float32)
-
-    extra_evoformer_input = {
-        'msa': extra_msa_activations,
-        'pair': pair_activations,
-    }
-    extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}
-
-    extra_evoformer_iteration = modules.EvoformerIteration(
-        c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
-
-    def extra_evoformer_fn(x):
-      act, safe_key = x
-      safe_key, safe_subkey = safe_key.split()
-      extra_evoformer_output = extra_evoformer_iteration(
-          activations=act,
-          masks=extra_masks,
-          is_training=is_training,
-          safe_key=safe_subkey)
-      return (extra_evoformer_output, safe_key)
-
-    if gc.use_remat:
-      extra_evoformer_fn = hk.remat(extra_evoformer_fn)
-
-    safe_key, safe_subkey = safe_key.split()
-    extra_evoformer_stack = layer_stack.layer_stack(
-        c.extra_msa_stack_num_block)(extra_evoformer_fn)
-    extra_evoformer_output, safe_key = extra_evoformer_stack(
-        (extra_evoformer_input, safe_subkey))
-
-    pair_activations = extra_evoformer_output['pair']
-
-    # Get the size of the MSA before potentially adding templates, so we
-    # can crop out the templates later.
-    num_msa_sequences = msa_activations.shape[0]
-    evoformer_input = {
-        'msa': msa_activations,
-        'pair': pair_activations,
-    }
-    evoformer_masks = {
-        'msa': batch['msa_mask'].astype(jnp.float32),
-        'pair': mask_2d
-    }
-
-    if c.template.enabled:
-      template_features, template_masks = (
-          template_embedding_1d(batch=batch, num_channel=c.msa_channel))
-
-      evoformer_input['msa'] = jnp.concatenate(
-          [evoformer_input['msa'], template_features], axis=0)
-      evoformer_masks['msa'] = jnp.concatenate(
-          [evoformer_masks['msa'], template_masks], axis=0)
-
-    evoformer_iteration = modules.EvoformerIteration(
-        c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
-
-    def evoformer_fn(x):
-      act, safe_key = x
-      safe_key, safe_subkey = safe_key.split()
-      evoformer_output = evoformer_iteration(
-          activations=act,
-          masks=evoformer_masks,
-          is_training=is_training,
-          safe_key=safe_subkey)
-      return (evoformer_output, safe_key)
-
-    if gc.use_remat:
-      evoformer_fn = hk.remat(evoformer_fn)
-
-    safe_key, safe_subkey = safe_key.split()
-    evoformer_stack = layer_stack.layer_stack(
-        c.evoformer_num_block)(evoformer_fn)
-
-    def run_evoformer(evoformer_input):
-      evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
-      return evoformer_output
-
-    evoformer_output = run_evoformer(evoformer_input)
-
-    msa_activations = evoformer_output['msa']
-    pair_activations = evoformer_output['pair']
-
-    single_activations = common_modules.Linear(
-        c.seq_channel, name='single_activations')(msa_activations[0])
+    with utils.bfloat16_context():
+      target_feat = jax.nn.one_hot(batch['aatype'], 21).astype(dtype)
+      preprocess_1d = common_modules.Linear(
+          c.msa_channel, name='preprocess_1d')(target_feat)
+
+      safe_key, sample_key, mask_key = safe_key.split(3)
+      batch = sample_msa(sample_key, batch, c.num_msa)
+      batch = make_masked_msa(batch, mask_key, c.masked_msa)
+
+      (batch['cluster_profile'],
+       batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)
+
+      msa_feat = create_msa_feat(batch).astype(dtype)
+
+      preprocess_msa = common_modules.Linear(
+          c.msa_channel, name='preprocess_msa')(msa_feat)
+      msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
+
+      left_single = common_modules.Linear(
+          c.pair_channel, name='left_single')(target_feat)
+      right_single = common_modules.Linear(
+          c.pair_channel, name='right_single')(target_feat)
+      pair_activations = left_single[:, None] + right_single[None]
+      mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
+      mask_2d = mask_2d.astype(dtype)
+
+      if c.recycle_pos:
+        prev_pseudo_beta = modules.pseudo_beta_fn(
+            batch['aatype'], batch['prev_pos'], None)
+
+        dgram = modules.dgram_from_positions(
+            prev_pseudo_beta, **self.config.prev_pos)
+        dgram = dgram.astype(dtype)
+        pair_activations += common_modules.Linear(
+            c.pair_channel, name='prev_pos_linear')(dgram)
+
+      if c.recycle_features:
+        prev_msa_first_row = common_modules.LayerNorm(
+            axis=[-1], create_scale=True, create_offset=True,
+            name='prev_msa_first_row_norm')(
+                batch['prev_msa_first_row']).astype(dtype)
+        msa_activations = msa_activations.at[0].add(prev_msa_first_row)
+
+        pair_activations += common_modules.LayerNorm(
+            axis=[-1], create_scale=True, create_offset=True,
+            name='prev_pair_norm')(batch['prev_pair']).astype(dtype)
+
+      if c.max_relative_idx:
+        pair_activations += self._relative_encoding(batch)
+
+      if c.template.enabled:
+        template_module = TemplateEmbedding(c.template, gc)
+        template_batch = {
+            'template_aatype': batch['template_aatype'],
+            'template_all_atom_positions':
+                batch['template_all_atom_positions'],
+            'template_all_atom_mask': batch['template_all_atom_mask']
+        }
+        # Construct a mask such that only intra-chain template features are
+        # computed, since all templates are for each chain individually.
+        multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
+        safe_key, safe_subkey = safe_key.split()
+        template_act = template_module(
+            query_embedding=pair_activations,
+            template_batch=template_batch,
+            padding_mask_2d=mask_2d,
+            multichain_mask_2d=multichain_mask,
+            is_training=is_training,
+            safe_key=safe_subkey)
+        pair_activations += template_act
+
+      # Extra MSA stack.
+      (extra_msa_feat,
+       extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
+      extra_msa_activations = common_modules.Linear(
+          c.extra_msa_channel,
+          name='extra_msa_activations')(extra_msa_feat).astype(dtype)
+      extra_msa_mask = extra_msa_mask.astype(dtype)
+
+      extra_evoformer_input = {
+          'msa': extra_msa_activations,
+          'pair': pair_activations,
+      }
+      extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}
+
+      extra_evoformer_iteration = modules.EvoformerIteration(
+          c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
+
+      def extra_evoformer_fn(x):
+        act, safe_key = x
+        safe_key, safe_subkey = safe_key.split()
+        extra_evoformer_output = extra_evoformer_iteration(
+            activations=act,
+            masks=extra_masks,
+            is_training=is_training,
+            safe_key=safe_subkey)
+        return (extra_evoformer_output, safe_key)
+
+      if gc.use_remat:
+        extra_evoformer_fn = hk.remat(extra_evoformer_fn)
+
+      safe_key, safe_subkey = safe_key.split()
+      extra_evoformer_stack = layer_stack.layer_stack(
+          c.extra_msa_stack_num_block)(extra_evoformer_fn)
+      extra_evoformer_output, safe_key = extra_evoformer_stack(
+          (extra_evoformer_input, safe_subkey))
+
+      pair_activations = extra_evoformer_output['pair']
+
+      # Get the size of the MSA before potentially adding templates, so we
+      # can crop out the templates later.
+      num_msa_sequences = msa_activations.shape[0]
+      evoformer_input = {
+          'msa': msa_activations,
+          'pair': pair_activations,
+      }
+      evoformer_masks = {
+          'msa': batch['msa_mask'].astype(dtype),
+          'pair': mask_2d
+      }
+
+      if c.template.enabled:
+        template_features, template_masks = (
+            template_embedding_1d(
+                batch=batch, num_channel=c.msa_channel, global_config=gc))
+
+        evoformer_input['msa'] = jnp.concatenate(
+            [evoformer_input['msa'], template_features], axis=0)
+        evoformer_masks['msa'] = jnp.concatenate(
+            [evoformer_masks['msa'], template_masks], axis=0)
+
+      evoformer_iteration = modules.EvoformerIteration(
+          c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
+
+      def evoformer_fn(x):
+        act, safe_key = x
+        safe_key, safe_subkey = safe_key.split()
+        evoformer_output = evoformer_iteration(
+            activations=act,
+            masks=evoformer_masks,
+            is_training=is_training,
+            safe_key=safe_subkey)
+        return (evoformer_output, safe_key)
+
+      if gc.use_remat:
+        evoformer_fn = hk.remat(evoformer_fn)
+
+      safe_key, safe_subkey = safe_key.split()
+      evoformer_stack = layer_stack.layer_stack(
+          c.evoformer_num_block)(evoformer_fn)
+
+      def run_evoformer(evoformer_input):
+        evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
+        return evoformer_output
+
+      evoformer_output = run_evoformer(evoformer_input)
+
+      msa_activations = evoformer_output['msa']
+      pair_activations = evoformer_output['pair']
+
+      single_activations = common_modules.Linear(
+          c.seq_channel, name='single_activations')(msa_activations[0])
 
     output.update({
         'single':

@@ -771,6 +807,12 @@ class EmbeddingsAndEvoformer(hk.Module):
         msa_activations[0],
     })
 
+    # Convert back to float32 if we're not saving memory.
+    if not gc.bfloat16_output:
+      for k, v in output.items():
+        if v.dtype == jnp.bfloat16:
+          output[k] = v.astype(jnp.float32)
+
     return output

@@ -917,6 +959,9 @@ class SingleTemplateEmbedding(hk.Module):
       # backbone affine - i.e. in each residues local frame, what direction are
       # each of the other residues.
       raw_atom_pos = template_all_atom_positions
+      if gc.bfloat16:
+        # Vec3Arrays are required to be float32
+        raw_atom_pos = raw_atom_pos.astype(jnp.float32)
 
       atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
       rigid, backbone_mask = folding_multimer.make_backbone_affine(

@@ -928,6 +973,10 @@ class SingleTemplateEmbedding(hk.Module):
       unit_vector = rigid_vec.normalized()
       unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]
 
+      if gc.bfloat16:
+        unit_vector = [x.astype(jnp.bfloat16) for x in unit_vector]
+        backbone_mask = backbone_mask.astype(jnp.bfloat16)
+
       backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
       backbone_mask_2d *= multichain_mask_2d
       unit_vector = [x*backbone_mask_2d for x in unit_vector]

@@ -937,7 +986,7 @@ class SingleTemplateEmbedding(hk.Module):
       to_concat.extend([(x, 0) for x in unit_vector])
       to_concat.append((backbone_mask_2d, 0))
 
-      query_embedding = hk.LayerNorm(
+      query_embedding = common_modules.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,

@@ -986,12 +1035,13 @@ class SingleTemplateEmbedding(hk.Module):
             template_iteration_fn)
     act, safe_key = template_stack((act, safe_subkey))
 
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,
         name='output_layer_norm')(
             act)
     return act

@@ -1044,21 +1094,18 @@ class TemplateEmbeddingIteration(hk.Module):
         act,
         pair_mask,
         safe_key=next(sub_keys))
 
     act = dropout_wrapper_fn(
         modules.TriangleAttention(c.triangle_attention_starting_node, gc,
                                   name='triangle_attention_starting_node'),
         act,
         pair_mask,
         safe_key=next(sub_keys))
 
     act = dropout_wrapper_fn(
         modules.TriangleAttention(c.triangle_attention_ending_node, gc,
                                   name='triangle_attention_ending_node'),
         act,
         pair_mask,
         safe_key=next(sub_keys))
 
     act = dropout_wrapper_fn(
         modules.Transition(c.pair_transition, gc, name='pair_transition'),

@@ -1069,7 +1116,7 @@ class TemplateEmbeddingIteration(hk.Module):
   return act
 
 
-def template_embedding_1d(batch, num_channel):
+def template_embedding_1d(batch, num_channel, global_config):
   """Embed templates into an (num_res, num_templates, num_channels) embedding.
 
   Args:

@@ -1080,6 +1127,7 @@ def template_embedding_1d(batch, num_channel):
     template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
       each template.
     num_channel: The number of channels in the output.
+    global_config: The global_config.
 
   Returns:
     An embedding of shape (num_templates, num_res, num_channels) and a mask of

@@ -1112,6 +1160,10 @@ def template_embedding_1d(batch, num_channel):
 
   template_mask = chi_mask[:, :, 0]
 
+  if global_config.bfloat16:
+    template_features = template_features.astype(jnp.bfloat16)
+    template_mask = template_mask.astype(jnp.bfloat16)
+
   template_activations = common_modules.Linear(
       num_channel,
       initializer='relu',
alphafold/model/utils.py

@@ -15,6 +15,7 @@
 """A collection of JAX utility functions for use in protein folding."""
 
 import collections
+import contextlib
 import functools
 import numbers
 from typing import Mapping

@@ -25,6 +26,27 @@ import jax.numpy as jnp
 import numpy as np
 
 
+def bfloat16_creator(next_creator, shape, dtype, init, context):
+  """Creates float32 variables when bfloat16 is requested."""
+  if context.original_dtype == jnp.bfloat16:
+    dtype = jnp.float32
+  return next_creator(shape, dtype, init)
+
+
+def bfloat16_getter(next_getter, value, context):
+  """Casts float32 to bfloat16 when bfloat16 was originally requested."""
+  if context.original_dtype == jnp.bfloat16:
+    assert value.dtype == jnp.float32
+    value = value.astype(jnp.bfloat16)
+  return next_getter(value)
+
+
+@contextlib.contextmanager
+def bfloat16_context():
+  with hk.custom_creator(bfloat16_creator), hk.custom_getter(bfloat16_getter):
+    yield
+
+
 def final_init(config):
   if config.zero_init:
     return 'zeros'
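The new bfloat16_context() keeps parameters stored in float32 while letting modules compute in bfloat16. A toy sketch of that behaviour; the module below is illustrative and not from the diff, and it assumes dm-haiku exposes original_dtype on its creator/getter contexts, which is what the code above relies on:

import haiku as hk
import jax
import jax.numpy as jnp

from alphafold.model import utils


def forward(x):
  # Request the parameter in the input dtype (bfloat16 here).
  w = hk.get_parameter('w', (x.shape[-1],), dtype=x.dtype,
                       init=hk.initializers.Constant(1.0))
  return x * w


def wrapped(x):
  with utils.bfloat16_context():
    return forward(x)


f = hk.transform(wrapped)
x = jnp.ones((3,), dtype=jnp.bfloat16)
params = f.init(jax.random.PRNGKey(0), x)
print(jax.tree_util.tree_map(lambda p: p.dtype, params))  # float32 storage
print(f.apply(params, None, x).dtype)                     # bfloat16 compute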
alphafold/notebooks/notebook_utils.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 """Helper methods for the AlphaFold Colab notebook."""
-import enum
 import json
 from typing import Any, Mapping, Optional, Sequence, Tuple
@@ -23,13 +22,7 @@ from matplotlib import pyplot as plt
 import numpy as np
 
 
-@enum.unique
-class ModelType(enum.Enum):
-  MONOMER = 0
-  MULTIMER = 1
-
-
-def clean_and_validate_sequence(
+def clean_and_validate_single_sequence(
     input_sequence: str, min_length: int, max_length: int) -> str:
   """Checks that the input sequence is ok and returns a clean version of it."""
   # Remove all whitespaces, tabs and end lines; upper-case.

@@ -54,41 +47,23 @@ def clean_and_validate_sequence(
   return clean_sequence
 
 
-def validate_input(
+def clean_and_validate_input_sequences(
     input_sequences: Sequence[str],
-    min_length: int,
-    max_length: int,
-    max_multimer_length: int) -> Tuple[Sequence[str], ModelType]:
-  """Validates and cleans input sequences and determines which model to use."""
+    min_sequence_length: int,
+    max_sequence_length: int) -> Sequence[str]:
+  """Validates and cleans input sequences."""
   sequences = []
 
   for input_sequence in input_sequences:
     if input_sequence.strip():
-      input_sequence = clean_and_validate_sequence(
+      input_sequence = clean_and_validate_single_sequence(
           input_sequence=input_sequence,
-          min_length=min_length,
-          max_length=max_length)
+          min_length=min_sequence_length,
+          max_length=max_sequence_length)
       sequences.append(input_sequence)
 
-  if len(sequences) == 1:
-    print('Using the single-chain model.')
-    return sequences, ModelType.MONOMER
-
-  elif len(sequences) > 1:
-    total_multimer_length = sum([len(seq) for seq in sequences])
-    if total_multimer_length > max_multimer_length:
-      raise ValueError(f'The total length of multimer sequences is too long: '
-                       f'{total_multimer_length}, while the maximum is '
-                       f'{max_multimer_length}. Please use the full AlphaFold '
-                       f'system for long multimers.')
-    elif total_multimer_length > 1536:
-      print('WARNING: The accuracy of the system has not been fully validated '
-            'above 1536 residues, and you may experience long running times or '
-            f'run out of memory for your complex with {total_multimer_length} '
-            'residues.')
-    print(f'Using the multimer model with {len(sequences)} sequences.')
-    return sequences, ModelType.MULTIMER
-
+  if sequences:
+    return sequences
   else:
     raise ValueError('No input amino acid sequence provided, please provide at '
                      'least one sequence.')
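A usage sketch of the renamed helper; the sequences below are illustrative:

from alphafold.notebooks import notebook_utils

seqs = notebook_utils.clean_and_validate_input_sequences(
    input_sequences=['MKTAYIAKQR', '', '  gidqtvsv  '],
    min_sequence_length=4,
    max_sequence_length=100)
print(seqs)  # ['MKTAYIAKQR', 'GIDQTVSV'] - empty entries dropped, upper-cased

Deciding between the monomer and multimer model is no longer done here; the function now only cleans and validates.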
alphafold/notebooks/notebook_utils_test.py

@@ -85,7 +85,7 @@ class NotebookUtilsTest(parameterized.TestCase):
       ('DeepMind', 'DEEPMIND'), ('A ', 'A'), ('\tA', 'A'), (' A \t\n ', 'A'),
       ('ACDEFGHIKLMNPQRSTVWY', 'ACDEFGHIKLMNPQRSTVWY'))
   def test_clean_and_validate_sequence_ok(self, sequence, exp_clean):
-    clean = notebook_utils.clean_and_validate_sequence(
+    clean = notebook_utils.clean_and_validate_single_sequence(
         sequence, min_length=1, max_length=100)
     self.assertEqual(clean, exp_clean)

@@ -100,35 +100,29 @@ class NotebookUtilsTest(parameterized.TestCase):
       ('bad_amino_acids_Z', 'ZZZZ', 'non-amino acid'))
   def test_clean_and_validate_sequence_bad(self, sequence, exp_error):
     with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
-      notebook_utils.clean_and_validate_sequence(
+      notebook_utils.clean_and_validate_single_sequence(
           sequence, min_length=4, max_length=8)
 
   @parameterized.parameters(
-      (['A', '', '', ' ', '\t', ' \t\n', '', ''], ['A'],
-       notebook_utils.ModelType.MONOMER),
-      (['', 'A'], ['A'], notebook_utils.ModelType.MONOMER),
-      (['A', 'C ', ''], ['A', 'C'], notebook_utils.ModelType.MULTIMER),
-      (['', 'A', '', 'C '], ['A', 'C'], notebook_utils.ModelType.MULTIMER))
-  def test_validate_input_ok(self, input_sequences, exp_sequences,
-                             exp_model_type):
-    sequences, model_type = notebook_utils.validate_input(
+      (['A', '', '', ' ', '\t', ' \t\n', '', ''], ['A']),
+      (['', 'A'], ['A']),
+      (['A', 'C ', ''], ['A', 'C']),
+      (['', 'A', '', 'C '], ['A', 'C']))
+  def test_validate_input_ok(self, input_sequences, exp_sequences):
+    sequences = notebook_utils.clean_and_validate_input_sequences(
         input_sequences=input_sequences,
-        min_length=1, max_length=100, max_multimer_length=100)
+        min_sequence_length=1, max_sequence_length=100)
     self.assertSequenceEqual(sequences, exp_sequences)
-    self.assertEqual(model_type, exp_model_type)
 
   @parameterized.named_parameters(
       ('no_input_sequence', ['', '\t', '\n'], 'No input amino acid sequence'),
       ('too_long_single', ['AAAAAAAAA', 'AAAA'], 'Input sequence is too long'),
-      ('too_long_multimer', ['AAAA', 'AAAAA'], 'The total length of multimer'))
+      ('too_short_single', ['AAA', 'AAAA'], 'Input sequence is too short'))
   def test_validate_input_bad(self, input_sequences, exp_error):
     with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
-      notebook_utils.validate_input(
+      notebook_utils.clean_and_validate_input_sequences(
           input_sequences=input_sequences,
-          min_length=4, max_length=8, max_multimer_length=6)
+          min_sequence_length=4, max_sequence_length=8)
 
   def test_merge_chunked_msa_no_hits(self):
     results = [ONLY_QUERY_HIT, ONLY_QUERY_HIT]
alphafold/relax/relax.py

@@ -56,7 +56,8 @@ class AmberRelaxation(object):
     self._use_gpu = use_gpu
 
   def process(self, *,
-              prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
+              prot: protein.Protein
+              ) -> Tuple[str, Dict[str, Any], Sequence[float]]:
     """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
     out = amber_minimize.run_pipeline(
         prot=prot, max_iterations=self._max_iterations,

@@ -73,12 +74,11 @@
         'attempts': out['min_attempts'], 'rmsd': rmsd
     }
-    pdb_str = amber_minimize.clean_protein(prot)
-    min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
+    min_pdb = out['min_pdb']
     min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
     utils.assert_equal_nonterminal_atom_types(
         protein.from_pdb_string(min_pdb).atom_mask,
         prot.atom_mask)
     violations = out['structural_violations'][
-        'total_per_residue_violations_mask']
+        'total_per_residue_violations_mask'].tolist()
     return min_pdb, debug_data, violations
alphafold/relax/relax_test.py

@@ -82,7 +82,7 @@ class RunAmberRelaxTest(absltest.TestCase):
         0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0,
         0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0])
     # Check no violations were added. Can't check exactly due to stochasticity.
-    self.assertTrue(np.all(num_violations <= exp_num_violations))
+    self.assertTrue(np.all(np.array(num_violations) <= exp_num_violations))
 
 
 if __name__ == '__main__':
alphafold/relax/utils.py

@@ -17,17 +17,6 @@ import io
 
 from alphafold.common import residue_constants
 from Bio import PDB
 import numpy as np
-from simtk.openmm import app as openmm_app
-from simtk.openmm.app.internal.pdbstructure import PdbStructure
-
-
-def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
-  pdb_file = io.StringIO(pdb_str)
-  structure = PdbStructure(pdb_file)
-  topology = openmm_app.PDBFile(structure).getTopology()
-  with io.StringIO() as f:
-    openmm_app.PDBFile.writeFile(topology, pos, f)
-    return f.getvalue()
 
 
 def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
docker/Dockerfile

@@ -21,7 +21,7 @@ ARG CUDA
 # Use bash to support string substitution.
 SHELL ["/bin/bash", "-o", "pipefail", "-c"]
 
 RUN apt-get update \
     && DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
       build-essential \
       cmake \

@@ -59,7 +59,7 @@ RUN conda install -qy conda==4.13.0 \
       cudatoolkit==${CUDA_VERSION} \
       pdbfixer \
       pip \
-      python=3.7 \
+      python=3.8 \
     && conda clean --all --force-pkgs-dirs --yes
 
 COPY . /app/alphafold

@@ -75,7 +75,7 @@ RUN pip3 install --upgrade pip --no-cache-dir \
     -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 
 # Apply OpenMM patch.
-WORKDIR /opt/conda/lib/python3.7/site-packages
+WORKDIR /opt/conda/lib/python3.8/site-packages
 RUN patch -p0 < /app/alphafold/docker/openmm.patch
 
 # Add SETUID bit to the ldconfig binary so that non-root users can run it.
docker/run_docker.py

@@ -133,7 +133,7 @@ def main(argv):
 
   # Path to the MGnify database for use by JackHMMER.
   mgnify_database_path = os.path.join(
-      FLAGS.data_dir, 'mgnify', 'mgy_clusters_2018_12.fa')
+      FLAGS.data_dir, 'mgnify', 'mgy_clusters_2022_05.fa')
 
   # Path to the BFD database for use by HHblits.
   bfd_database_path = os.path.join(

@@ -144,9 +144,9 @@ def main(argv):
   small_bfd_database_path = os.path.join(
       FLAGS.data_dir, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta')
 
-  # Path to the Uniclust30 database for use by HHblits.
-  uniclust30_database_path = os.path.join(
-      FLAGS.data_dir, 'uniclust30', 'uniclust30_2018_08', 'uniclust30_2018_08')
+  # Path to the UniRef30 database for use by HHblits.
+  uniref30_database_path = os.path.join(
+      FLAGS.data_dir, 'uniref30', 'UniRef30_2021_03')
 
   # Path to the PDB70 database for use by HHsearch.
   pdb70_database_path = os.path.join(FLAGS.data_dir, 'pdb70', 'pdb70')

@@ -199,7 +199,7 @@ def main(argv):
     database_paths.append(('small_bfd_database_path', small_bfd_database_path))
   else:
     database_paths.extend([
-        ('uniclust30_database_path', uniclust30_database_path),
+        ('uniref30_database_path', uniref30_database_path),
        ('bfd_database_path', bfd_database_path),
     ])
   for name, path in database_paths:
docs/casp15_predictions.zip

New file (mode 0 → 100644): binary file added.