ModelZoo / alphafold2_jax · Commit 9b18d6a9

Release code for v2.3.0

PiperOrigin-RevId: 494507694

Authored Dec 11, 2022 by Augustin Zidek
Parent: 4494af84
Changes: 30 · Showing 20 changed files with 667 additions and 430 deletions (+667, -430)
    README.md                                     +130  -112
    afdb/README.md                                  +1    -1
    alphafold/data/pipeline.py                      +7    -7
    alphafold/model/common_modules.py              +61    -0
    alphafold/model/config.py                      +64   -24
    alphafold/model/folding.py                      +4    -4
    alphafold/model/folding_multimer.py             +4    -4
    alphafold/model/geometry/struct_of_array.py     +2    -2
    alphafold/model/mapping.py                      +3    -3
    alphafold/model/modules.py                    +109   -23
    alphafold/model/modules_multimer.py           +224  -172
    alphafold/model/utils.py                       +22    -0
    alphafold/notebooks/notebook_utils.py          +10   -35
    alphafold/notebooks/notebook_utils_test.py     +13   -19
    alphafold/relax/relax.py                        +4    -4
    alphafold/relax/relax_test.py                   +1    -1
    alphafold/relax/utils.py                        +0   -11
    docker/Dockerfile                               +3    -3
    docker/run_docker.py                            +5    -5
    docs/casp15_predictions.zip                     +0    -0
README.md
[large diff (+130, -112) collapsed in this capture]
afdb/README.md

@@ -304,9 +304,9 @@ fractionPlddtVeryHigh | `FLOAT64` | Fraction of the residues in the predi
 fractionPlddtVeryLow | `FLOAT64` | Fraction of the residues in the prediction with pLDDT less than 50
 gene | `STRING` | The name of the gene if known, e.g. "COII"
 geneSynonyms | `ARRAY<STRING>` | Additional synonyms for the gene
-globalMetricValue | `FLOAT64` | The mean pLDDT of this prediction
 isReferenceProteome | `BOOL` | Is this protein part of the reference proteome?
 isReviewed | `BOOL` | Has this protein been reviewed, i.e. is it part of SwissProt?
+globalMetricValue | `FLOAT64` | The mean pLDDT of this prediction
 latestVersion | `INT64` | The latest AFDB version for this prediction
 modelCreatedDate | `DATE` | The date of creation for this entry, e.g. "2022-06-01"
 organismCommonNames | `ARRAY<STRING>` | List of common organism names
alphafold/data/pipeline.py

@@ -117,7 +117,7 @@ class DataPipeline:
       uniref90_database_path: str,
       mgnify_database_path: str,
       bfd_database_path: Optional[str],
-      uniclust30_database_path: Optional[str],
+      uniref30_database_path: Optional[str],
       small_bfd_database_path: Optional[str],
       template_searcher: TemplateSearcher,
       template_featurizer: templates.TemplateHitFeaturizer,

@@ -135,9 +135,9 @@ class DataPipeline:
           binary_path=jackhmmer_binary_path,
           database_path=small_bfd_database_path)
     else:
-      self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
+      self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
           binary_path=hhblits_binary_path,
-          databases=[bfd_database_path, uniclust30_database_path])
+          databases=[bfd_database_path, uniref30_database_path])
     self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
         binary_path=jackhmmer_binary_path,
         database_path=mgnify_database_path)

@@ -211,14 +211,14 @@ class DataPipeline:
           use_precomputed_msas=self.use_precomputed_msas)
       bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
     else:
-      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
-      hhblits_bfd_uniclust_result = run_msa_tool(
-          msa_runner=self.hhblits_bfd_uniclust_runner,
+      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
+      hhblits_bfd_uniref_result = run_msa_tool(
+          msa_runner=self.hhblits_bfd_uniref_runner,
           input_fasta_path=input_fasta_path,
           msa_out_path=bfd_out_path,
           msa_format='a3m',
           use_precomputed_msas=self.use_precomputed_msas)
-      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
+      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])

     templates_result = self.template_featurizer.get_templates(
         query_sequence=input_sequence,
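The Uniclust30-to-UniRef30 rename also changes the on-disk MSA cache: runs that previously wrote bfd_uniclust_hits.a3m now look for bfd_uniref_hits.a3m when use_precomputed_msas is enabled. A minimal migration sketch (not part of this commit, and only sensible if you deliberately want to reuse old Uniclust-based MSAs instead of recomputing against UniRef30):

    import os

    def migrate_msa_dir(msa_output_dir: str) -> None:
      """Renames a v2.2-era cached MSA so the v2.3 pipeline can find it."""
      old = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
      new = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
      if os.path.exists(old) and not os.path.exists(new):
        os.rename(old, new)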
alphafold/model/common_modules.py

@@ -128,3 +128,64 @@ class Linear(hk.Module):
     return output
+
+
+class LayerNorm(hk.LayerNorm):
+  """LayerNorm module.
+
+  Equivalent to hk.LayerNorm but with different parameter shapes: they are
+  always vectors rather than possibly higher-rank tensors. This makes it
+  easier to change the layout whilst keeping the model weight-compatible.
+  """
+
+  def __init__(self,
+               axis,
+               create_scale: bool,
+               create_offset: bool,
+               eps: float = 1e-5,
+               scale_init=None,
+               offset_init=None,
+               use_fast_variance: bool = False,
+               name=None,
+               param_axis=None):
+    super().__init__(
+        axis=axis,
+        create_scale=False,
+        create_offset=False,
+        eps=eps,
+        scale_init=None,
+        offset_init=None,
+        use_fast_variance=use_fast_variance,
+        name=name,
+        param_axis=param_axis)
+    self._temp_create_scale = create_scale
+    self._temp_create_offset = create_offset
+
+  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+    is_bf16 = (x.dtype == jnp.bfloat16)
+    if is_bf16:
+      x = x.astype(jnp.float32)
+
+    param_axis = self.param_axis[0] if self.param_axis else -1
+    param_shape = (x.shape[param_axis],)
+    param_broadcast_shape = [1] * x.ndim
+    param_broadcast_shape[param_axis] = x.shape[param_axis]
+    scale = None
+    offset = None
+    if self._temp_create_scale:
+      scale = hk.get_parameter(
+          'scale', param_shape, x.dtype, init=self.scale_init)
+      scale = scale.reshape(param_broadcast_shape)
+
+    if self._temp_create_offset:
+      offset = hk.get_parameter(
+          'offset', param_shape, x.dtype, init=self.offset_init)
+      offset = offset.reshape(param_broadcast_shape)
+
+    out = super().__call__(x, scale=scale, offset=offset)
+
+    if is_bf16:
+      out = out.astype(jnp.bfloat16)
+
+    return out
\ No newline at end of file
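The bfloat16 round-trip is the point of this subclass: inputs may arrive as bfloat16, but the normalisation itself runs in float32 and the parameters stay plain vectors. A minimal behavioural sketch under hk.transform (illustrative, not from the commit; the shapes assume a (4, 8) input):

    import haiku as hk
    import jax
    import jax.numpy as jnp

    from alphafold.model import common_modules

    def forward(x):
      return common_modules.LayerNorm(
          axis=[-1], create_scale=True, create_offset=True, name='norm')(x)

    fn = hk.transform(forward)
    x = jnp.ones((4, 8), dtype=jnp.bfloat16)
    params = fn.init(jax.random.PRNGKey(0), x)
    # Scale/offset are vectors of shape (8,), created in float32 because the
    # input is cast up before hk.get_parameter is reached.
    print(jax.tree_util.tree_map(lambda p: (p.shape, p.dtype), params))
    out = fn.apply(params, None, x)
    print(out.dtype)  # bfloat16: cast back after the float32 LayerNorm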
alphafold/model/config.py

@@ -26,11 +26,11 @@ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
 def model_config(name: str) -> ml_collections.ConfigDict:
   """Get the ConfigDict of a CASP14 model."""

-  if 'multimer' in name:
-    return CONFIG_MULTIMER
   if name not in CONFIG_DIFFS:
     raise ValueError(f'Invalid model name {name}.')
-  cfg = copy.deepcopy(CONFIG)
+  if 'multimer' in name:
+    cfg = copy.deepcopy(CONFIG_MULTIMER)
+  else:
+    cfg = copy.deepcopy(CONFIG)
   cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
   return cfg

@@ -52,11 +52,11 @@ MODEL_PRESETS = {
         'model_5_ptm',
     ),
     'multimer': (
-        'model_1_multimer_v2',
-        'model_2_multimer_v2',
-        'model_3_multimer_v2',
-        'model_4_multimer_v2',
-        'model_5_multimer_v2',
+        'model_1_multimer_v3',
+        'model_2_multimer_v3',
+        'model_3_multimer_v3',
+        'model_4_multimer_v3',
+        'model_5_multimer_v3',
     ),
 }
 MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']

@@ -118,8 +118,32 @@ CONFIG_DIFFS = {
     },
     'model_5_ptm': {
-        'model.heads.predicted_aligned_error.weight': 0.1
-    }
+        'model.heads.predicted_aligned_error.weight': 0.1
+    },
+    'model_1_multimer_v3': {},
+    'model_2_multimer_v3': {},
+    'model_3_multimer_v3': {},
+    'model_4_multimer_v3': {
+        'model.embeddings_and_evoformer.num_extra_msa': 1152
+    },
+    'model_5_multimer_v3': {
+        'model.embeddings_and_evoformer.num_extra_msa': 1152
+    },
 }

+# Key differences between multimer v1/v2 and v3, mostly due to numerical
+# optimisations in the TriangleMultiplication module.
+common_updates = {
+    'model.embeddings_and_evoformer.num_msa': 252,
+    'model.embeddings_and_evoformer.num_extra_msa': 1152,
+    'model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_incoming.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_outgoing.fuse_projection_weights': False,
+}
+CONFIG_DIFFS.update(
+    {f'model_{i}_multimer': common_updates for i in range(1, 6)})
+CONFIG_DIFFS.update(
+    {f'model_{i}_multimer_v2': common_updates for i in range(1, 6)})

 CONFIG = ml_collections.ConfigDict({
     'data': {

@@ -260,14 +284,16 @@ CONFIG = ml_collections.ConfigDict({
             'equation': 'ikc,jkc->ijc',
             'num_intermediate_channel': 128,
             'orientation': 'per_row',
-            'shared_dropout': True
+            'shared_dropout': True,
+            'fuse_projection_weights': False,
         },
         'triangle_multiplication_incoming': {
             'dropout_rate': 0.25,
             'equation': 'kjc,kic->ijc',
             'num_intermediate_channel': 128,
             'orientation': 'per_row',
-            'shared_dropout': True
+            'shared_dropout': True,
+            'fuse_projection_weights': False,
         },
         'pair_transition': {
             'dropout_rate': 0.0,

@@ -328,14 +354,16 @@ CONFIG = ml_collections.ConfigDict({
             'equation': 'ikc,jkc->ijc',
             'num_intermediate_channel': 64,
             'orientation': 'per_row',
-            'shared_dropout': True
+            'shared_dropout': True,
+            'fuse_projection_weights': False,
         },
         'triangle_multiplication_incoming': {
             'dropout_rate': 0.25,
             'equation': 'kjc,kic->ijc',
             'num_intermediate_channel': 64,
             'orientation': 'per_row',
-            'shared_dropout': True
+            'shared_dropout': True,
+            'fuse_projection_weights': False,
         },
         'pair_transition': {
             'dropout_rate': 0.0,

@@ -354,7 +382,7 @@ CONFIG = ml_collections.ConfigDict({
         'multimer_mode': False,
         'subbatch_size': 4,
         'use_remat': False,
-        'zero_init': True
+        'zero_init': True,
     },
     'heads': {
         'distogram': {

@@ -483,27 +511,29 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
             'gating': True,
             'num_head': 4,
             'orientation': 'per_row',
-            'shared_dropout': True
+            'shared_dropout': True,
         },
         'triangle_multiplication_incoming': {
             'dropout_rate': 0.25,
             'equation': 'kjc,kic->ijc',
             'num_intermediate_channel': 128,
             'orientation': 'per_row',
-            'shared_dropout': True
+            'shared_dropout': True,
+            'fuse_projection_weights': True,
         },
         'triangle_multiplication_outgoing': {
             'dropout_rate': 0.25,
             'equation': 'ikc,jkc->ijc',
             'num_intermediate_channel': 128,
             'orientation': 'per_row',
-            'shared_dropout': True
+            'shared_dropout': True,
+            'fuse_projection_weights': True,
         }
     },
     'extra_msa_channel': 64,
     'extra_msa_stack_num_block': 4,
-    'num_msa': 252,
-    'num_extra_msa': 1152,
+    'num_msa': 508,
+    'num_extra_msa': 2048,
     'masked_msa': {
         'profile_prob': 0.1,
         'replace_fraction': 0.15,

@@ -564,24 +594,28 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
             'equation': 'kjc,kic->ijc',
             'num_intermediate_channel': 64,
             'orientation': 'per_row',
-            'shared_dropout': True
+            'shared_dropout': True,
+            'fuse_projection_weights': True,
         },
         'triangle_multiplication_outgoing': {
             'dropout_rate': 0.25,
             'equation': 'ikc,jkc->ijc',
             'num_intermediate_channel': 64,
             'orientation': 'per_row',
-            'shared_dropout': True
+            'shared_dropout': True,
+            'fuse_projection_weights': True,
         }
     }
     },
     'global_config': {
+        'bfloat16': True,
+        'bfloat16_output': False,
         'deterministic': False,
         'multimer_mode': True,
         'subbatch_size': 4,
         'use_remat': False,
-        'zero_init': True
+        'zero_init': True,
     },
     'heads': {
         'distogram': {

@@ -651,7 +685,13 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
         }
     },
     'num_ensemble_eval': 1,
-    'num_recycle': 3,
+    'num_recycle': 20,
+    # A negative value indicates that no early stopping will occur, i.e.
+    # the model will always run `num_recycle` number of recycling
+    # iterations. A positive value will enable early stopping if the
+    # difference in pairwise distances is less than the tolerance between
+    # recycling steps.
+    'recycle_early_stop_tolerance': 0.5,
     'resample_msa_in_recycling': True
 })
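A short sketch of how these config changes surface through model_config (field paths follow the diff above; the override is illustrative):

    from alphafold.model import config

    cfg = config.model_config('model_1_multimer_v3')
    print(cfg.model.num_recycle)                   # 20 in v2.3.0, up from 3
    print(cfg.model.recycle_early_stop_tolerance)  # 0.5
    # A negative tolerance disables early stopping, so all recycles run:
    cfg.model.recycle_early_stop_tolerance = -1.0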
alphafold/model/folding.py

@@ -331,7 +331,7 @@ class FoldIteration(hk.Module):
       safe_key, *sub_keys = safe_key.split(3)
       sub_keys = iter(sub_keys)
       act = safe_dropout_fn(act, next(sub_keys))
-      act = hk.LayerNorm(
+      act = common_modules.LayerNorm(
           axis=[-1],
           create_scale=True,
           create_offset=True,

@@ -353,7 +353,7 @@ class FoldIteration(hk.Module):
       act = jax.nn.relu(act)
     act += input_act
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,

@@ -410,7 +410,7 @@ def generate_affines(representations, batch, config, global_config,
   c = config
   sequence_mask = batch['seq_mask'][:, None]

-  act = hk.LayerNorm(
+  act = common_modules.LayerNorm(
       axis=[-1],
       create_scale=True,
       create_offset=True,

@@ -433,7 +433,7 @@ def generate_affines(representations, batch, config, global_config,
       'affine': affine.to_tensor(),
   }

-  act_2d = hk.LayerNorm(
+  act_2d = common_modules.LayerNorm(
       axis=[-1],
       create_scale=True,
       create_offset=True,
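These swaps are intended to be checkpoint-compatible. A hedged sanity check (not from the commit) that the old and new modules register identical parameter trees, so existing weights load unchanged:

    import haiku as hk
    import jax
    import jax.numpy as jnp

    from alphafold.model import common_modules

    def old_fn(x):
      return hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
                          name='norm')(x)

    def new_fn(x):
      return common_modules.LayerNorm(axis=[-1], create_scale=True,
                                      create_offset=True, name='norm')(x)

    x = jnp.ones((2, 8))
    old_params = hk.transform(old_fn).init(jax.random.PRNGKey(0), x)
    new_params = hk.transform(new_fn).init(jax.random.PRNGKey(0), x)
    # Same module paths and parameter names ('scale', 'offset') in both trees:
    assert (jax.tree_util.tree_structure(old_params) ==
            jax.tree_util.tree_structure(new_params))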
alphafold/model/folding_multimer.py

@@ -427,7 +427,7 @@ class FoldIteration(hk.Module):
       safe_key, *sub_keys = safe_key.split(3)
       sub_keys = iter(sub_keys)
       act = safe_dropout_fn(act, next(sub_keys))
-      act = hk.LayerNorm(
+      act = common_modules.LayerNorm(
           axis=-1,
           create_scale=True,
           create_offset=True,

@@ -448,7 +448,7 @@ class FoldIteration(hk.Module):
     act = jax.nn.relu(act)
     act += input_act
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=-1,
         create_scale=True,
         create_offset=True,

@@ -500,7 +500,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
   """
   c = config
   sequence_mask = batch['seq_mask'][:, None]
-  act = hk.LayerNorm(
+  act = common_modules.LayerNorm(
       axis=-1, create_scale=True, create_offset=True,
       name='single_layer_norm')(
           representations['single'])

@@ -523,7 +523,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
       rigid
   }
-  act_2d = hk.LayerNorm(
+  act_2d = common_modules.LayerNorm(
       axis=-1,
       create_scale=True,
       create_offset=True,
alphafold/model/geometry/struct_of_array.py

@@ -133,7 +133,7 @@ def flatten(instance):
   inner_treedefs = []
   num_arrays = []
   for array_like in array_likes:
-    flat_array_like, inner_treedef = jax.tree_flatten(array_like)
+    flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like)
     inner_treedefs.append(inner_treedef)
     flat_array_likes += flat_array_like
     num_arrays.append(len(flat_array_like))

@@ -206,7 +206,7 @@ class StructOfArray:
       for num_array, inner_treedef, array_field in zip(num_arrays,
                                                        inner_treedefs,
                                                        array_fields):
-        value_dict[array_field] = jax.tree_unflatten(
+        value_dict[array_field] = jax.tree_util.tree_unflatten(
            inner_treedef, data[array_start:array_start + num_array])
        array_start += num_array
      metadata_fields = get_metadata_fields(new_cls)
alphafold/model/mapping.py

@@ -47,11 +47,11 @@ def _maybe_get_size(array, axis):
 def _expand_axes(axes, values, name='sharded_apply'):
-  values_tree_def = jax.tree_flatten(values)[1]
+  values_tree_def = jax.tree_util.tree_flatten(values)[1]
   flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
   # Replace None's with PROXY
   flat_axes = [PROXY if x is None else x for x in flat_axes]
-  return jax.tree_unflatten(values_tree_def, flat_axes)
+  return jax.tree_util.tree_unflatten(values_tree_def, flat_axes)


 def sharded_map(

@@ -126,7 +126,7 @@ def sharded_apply(
     in_axes_ = _expand_axes(in_axes, args)

    in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
-    flat_sizes = jax.tree_flatten(in_sizes)[0]
+    flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0]
    in_size = max(flat_sizes)
    assert all(i in {in_size, -1} for i in flat_sizes)
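Context for this and the struct_of_array.py change: newer JAX releases deprecate the top-level jax.tree_flatten / jax.tree_unflatten aliases in favour of the stable jax.tree_util spellings. The behaviour is unchanged, as a tiny sketch shows:

    import jax

    leaves, treedef = jax.tree_util.tree_flatten({'a': 1, 'b': (2, 3)})
    assert leaves == [1, 2, 3]
    assert jax.tree_util.tree_unflatten(treedef, leaves) == {'a': 1, 'b': (2, 3)}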
alphafold/model/modules.py

@@ -501,7 +501,7 @@ class Transition(hk.Module):
     num_intermediate = int(nc * self.config.num_intermediate_factor)
     mask = jnp.expand_dims(mask, axis=-1)

-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,

@@ -569,12 +569,15 @@ class Attention(hk.Module):
     q_weights = hk.get_parameter(
         'query_w', shape=(q_data.shape[-1], num_head, key_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
     k_weights = hk.get_parameter(
         'key_w', shape=(m_data.shape[-1], num_head, key_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
     v_weights = hk.get_parameter(
         'value_w', shape=(m_data.shape[-1], num_head, value_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())

     q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)

@@ -595,10 +598,12 @@ class Attention(hk.Module):
       gating_weights = hk.get_parameter(
           'gating_w',
           shape=(q_data.shape[-1], num_head, value_dim),
+          dtype=q_data.dtype,
           init=hk.initializers.Constant(0.0))
       gating_bias = hk.get_parameter(
           'gating_b',
           shape=(num_head, value_dim),
+          dtype=q_data.dtype,
           init=hk.initializers.Constant(1.0))

       gate_values = jnp.einsum('bqc, chv->bqhv', q_data,

@@ -610,8 +615,11 @@ class Attention(hk.Module):
     o_weights = hk.get_parameter(
         'output_w', shape=(num_head, value_dim, self.output_dim),
+        dtype=q_data.dtype,
         init=init)
     o_bias = hk.get_parameter(
         'output_b', shape=(self.output_dim,),
+        dtype=q_data.dtype,
         init=hk.initializers.Constant(0.0))

     output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias

@@ -658,12 +666,15 @@ class GlobalAttention(hk.Module):
     q_weights = hk.get_parameter(
         'query_w', shape=(q_data.shape[-1], num_head, key_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
     k_weights = hk.get_parameter(
         'key_w', shape=(m_data.shape[-1], key_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
     v_weights = hk.get_parameter(
         'value_w', shape=(m_data.shape[-1], value_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())

     v = jnp.einsum('bka,ac->bkc', m_data, v_weights)

@@ -684,18 +695,23 @@ class GlobalAttention(hk.Module):
     o_weights = hk.get_parameter(
         'output_w', shape=(num_head, value_dim, self.output_dim),
+        dtype=q_data.dtype,
         init=init)
     o_bias = hk.get_parameter(
         'output_b', shape=(self.output_dim,),
+        dtype=q_data.dtype,
         init=hk.initializers.Constant(0.0))

     if self.config.gating:
       gating_weights = hk.get_parameter(
           'gating_w',
           shape=(q_data.shape[-1], num_head, value_dim),
+          dtype=q_data.dtype,
           init=hk.initializers.Constant(0.0))
       gating_bias = hk.get_parameter(
           'gating_b',
           shape=(num_head, value_dim),
+          dtype=q_data.dtype,
           init=hk.initializers.Constant(1.0))
       gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights)

@@ -745,11 +761,11 @@ class MSARowAttentionWithPairBias(hk.Module):
     bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4

-    msa_act = hk.LayerNorm(
+    msa_act = common_modules.LayerNorm(
         axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
             msa_act)

-    pair_act = hk.LayerNorm(
+    pair_act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,

@@ -760,6 +776,7 @@ class MSARowAttentionWithPairBias(hk.Module):
     weights = hk.get_parameter(
         'feat_2d_weights',
         shape=(pair_act.shape[-1], c.num_head),
+        dtype=msa_act.dtype,
         init=hk.initializers.RandomNormal(stddev=init_factor))
     nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)

@@ -812,7 +829,7 @@ class MSAColumnAttention(hk.Module):
     bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4

-    msa_act = hk.LayerNorm(
+    msa_act = common_modules.LayerNorm(
         axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
             msa_act)

@@ -867,7 +884,7 @@ class MSAColumnGlobalAttention(hk.Module):
     bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4

-    msa_act = hk.LayerNorm(
+    msa_act = common_modules.LayerNorm(
         axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
             msa_act)

@@ -924,7 +941,7 @@ class TriangleAttention(hk.Module):
     bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4

-    pair_act = hk.LayerNorm(
+    pair_act = common_modules.LayerNorm(
         axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
             pair_act)

@@ -932,6 +949,7 @@ class TriangleAttention(hk.Module):
     weights = hk.get_parameter(
         'feat_2d_weights',
         shape=(pair_act.shape[-1], c.num_head),
+        dtype=pair_act.dtype,
         init=hk.initializers.RandomNormal(stddev=init_factor))
     nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)

@@ -1029,7 +1047,7 @@ class PredictedLDDTHead(hk.Module):
     """
     act = representations['structure_module']

-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,

@@ -1251,6 +1269,19 @@ class ExperimentallyResolvedHead(hk.Module):
     return output


+def _layer_norm(axis=-1, name='layer_norm'):
+  return common_modules.LayerNorm(
+      axis=axis,
+      create_scale=True,
+      create_offset=True,
+      eps=1e-5,
+      use_fast_variance=True,
+      scale_init=hk.initializers.Constant(1.),
+      offset_init=hk.initializers.Constant(0.),
+      param_axis=axis,
+      name=name)
+
+
 class TriangleMultiplication(hk.Module):
   """Triangle multiplication layer ("outgoing" or "incoming").

@@ -1263,25 +1294,34 @@ class TriangleMultiplication(hk.Module):
     self.config = config
     self.global_config = global_config

-  def __call__(self, act, mask, is_training=True):
+  def __call__(self, left_act, left_mask, is_training=True):
     """Builds TriangleMultiplication module.

     Arguments:
-      act: Pair activations, shape [N_res, N_res, c_z]
-      mask: Pair mask, shape [N_res, N_res].
+      left_act: Pair activations, shape [N_res, N_res, c_z]
+      left_mask: Pair mask, shape [N_res, N_res].
       is_training: Whether the module is in training mode.

     Returns:
-      Outputs, same shape/type as act.
+      Outputs, same shape/type as left_act.
     """
     del is_training

+    if self.config.fuse_projection_weights:
+      return self._fused_triangle_multiplication(left_act, left_mask)
+    else:
+      return self._triangle_multiplication(left_act, left_mask)
+
+  @hk.transparent
+  def _triangle_multiplication(self, left_act, left_mask):
+    """Implementation of TriangleMultiplication used in AF2 and AF-M<2.3."""
     c = self.config
     gc = self.global_config

-    mask = mask[..., None]
+    mask = left_mask[..., None]

-    act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
-                       name='layer_norm_input')(act)
+    act = common_modules.LayerNorm(axis=[-1], create_scale=True,
+                                   create_offset=True,
+                                   name='layer_norm_input')(left_act)
     input_act = act

     left_projection = common_modules.Linear(

@@ -1317,7 +1357,7 @@ class TriangleMultiplication(hk.Module):
     # b = left_proj_act and a = right_proj_act
     act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,

@@ -1340,6 +1380,50 @@ class TriangleMultiplication(hk.Module):

     return act

+  @hk.transparent
+  def _fused_triangle_multiplication(self, left_act, left_mask):
+    """TriangleMultiplication with fused projection weights."""
+    mask = left_mask[..., None]
+    c = self.config
+    gc = self.global_config
+
+    left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)
+
+    # Both left and right projections are fused into projection.
+    projection = common_modules.Linear(
+        2 * c.num_intermediate_channel, name='projection')
+    proj_act = mask * projection(left_act)
+
+    # Both left + right gate are fused into gate_values.
+    gate_values = common_modules.Linear(
+        2 * c.num_intermediate_channel,
+        name='gate',
+        bias_init=1.,
+        initializer=utils.final_init(gc))(left_act)
+    proj_act *= jax.nn.sigmoid(gate_values)
+
+    left_proj_act = proj_act[:, :, :c.num_intermediate_channel]
+    right_proj_act = proj_act[:, :, c.num_intermediate_channel:]
+    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
+
+    act = _layer_norm(axis=-1, name='center_norm')(act)
+
+    output_channel = int(left_act.shape[-1])
+
+    act = common_modules.Linear(
+        output_channel,
+        initializer=utils.final_init(gc),
+        name='output_projection')(act)
+
+    gate_values = common_modules.Linear(
+        output_channel,
+        bias_init=1.,
+        initializer=utils.final_init(gc),
+        name='gating_linear')(left_act)
+    act *= jax.nn.sigmoid(gate_values)
+
+    return act
+

 class DistogramHead(hk.Module):
   """Head to predict a distogram.

@@ -1446,7 +1530,7 @@ class OuterProductMean(hk.Module):
     c = self.config

     mask = mask[..., None]
-    act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act)
+    act = common_modules.LayerNorm([-1], True, True, name='layer_norm_input')(act)

     left_act = mask * common_modules.Linear(
         c.num_outer_channel,

@@ -1469,9 +1553,11 @@ class OuterProductMean(hk.Module):
         'output_w',
         shape=(c.num_outer_channel, c.num_outer_channel,
                self.num_output_channel),
+        dtype=act.dtype,
         init=init_w)
     output_b = hk.get_parameter(
         'output_b', shape=(self.num_output_channel,),
+        dtype=act.dtype,
         init=hk.initializers.Constant(0.0))

     def compute_chunk(left_act):

@@ -1738,7 +1824,7 @@ class EmbeddingsAndEvoformer(hk.Module):
               dgram)

     if c.recycle_features:
-      prev_msa_first_row = hk.LayerNorm(
+      prev_msa_first_row = common_modules.LayerNorm(
           axis=[-1],
           create_scale=True,
           create_offset=True,

@@ -1746,7 +1832,7 @@ class EmbeddingsAndEvoformer(hk.Module):
           batch['prev_msa_first_row'])
       msa_activations = msa_activations.at[0].add(prev_msa_first_row)

-      pair_activations += hk.LayerNorm(
+      pair_activations += common_modules.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,

@@ -2020,7 +2106,7 @@ class SingleTemplateEmbedding(hk.Module):
         self.config.template_pair_stack, self.global_config)(
             act, mask_2d, is_training)

-    act = hk.LayerNorm([-1], True, True, name='output_layer_norm')(act)
+    act = common_modules.LayerNorm([-1], True, True, name='output_layer_norm')(act)

     return act
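Why fusing the left and right projections is safe, as a small numerical sketch (illustrative, not from the commit): applying two linear projections to the same input and concatenating the results equals one projection with concatenated weight columns, which is what fuse_projection_weights exploits to halve the number of matmuls in TriangleMultiplication:

    import numpy as np

    x = np.random.randn(5, 7).astype(np.float32)
    w_left = np.random.randn(7, 4).astype(np.float32)
    w_right = np.random.randn(7, 4).astype(np.float32)

    # Two separate projections, concatenated afterwards...
    separate = np.concatenate([x @ w_left, x @ w_right], axis=-1)
    # ...equal one fused projection with stacked weight columns.
    fused = x @ np.concatenate([w_left, w_right], axis=-1)
    np.testing.assert_allclose(separate, fused, rtol=1e-5)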
alphafold/model/modules_multimer.py
View file @
9b18d6a9
...
@@ -475,20 +475,51 @@ class AlphaFold(hk.Module):
...
@@ -475,20 +475,51 @@ class AlphaFold(hk.Module):
# Eval mode or tests: use the maximum number of iterations.
# Eval mode or tests: use the maximum number of iterations.
num_iter
=
c
.
num_recycle
num_iter
=
c
.
num_recycle
def
recycle_body
(
i
,
x
):
def
distances
(
points
):
del
i
"""Compute all pairwise distances for a set of points."""
prev
,
safe_key
=
x
return
jnp
.
sqrt
(
jnp
.
sum
((
points
[:,
None
]
-
points
[
None
,
:])
**
2
,
axis
=-
1
))
def
recycle_body
(
x
):
i
,
_
,
prev
,
safe_key
=
x
safe_key1
,
safe_key2
=
safe_key
.
split
()
if
c
.
resample_msa_in_recycling
else
safe_key
.
duplicate
()
# pylint: disable=line-too-long
safe_key1
,
safe_key2
=
safe_key
.
split
()
if
c
.
resample_msa_in_recycling
else
safe_key
.
duplicate
()
# pylint: disable=line-too-long
ret
=
apply_network
(
prev
=
prev
,
safe_key
=
safe_key2
)
ret
=
apply_network
(
prev
=
prev
,
safe_key
=
safe_key2
)
return
get_prev
(
ret
),
safe_key1
return
i
+
1
,
prev
,
get_prev
(
ret
),
safe_key1
prev
,
safe_key
=
hk
.
fori_loop
(
0
,
num_iter
,
recycle_body
,
(
prev
,
safe_key
))
def
recycle_cond
(
x
):
i
,
prev
,
next_in
,
_
=
x
ca_idx
=
residue_constants
.
atom_order
[
'CA'
]
sq_diff
=
jnp
.
square
(
distances
(
prev
[
'prev_pos'
][:,
ca_idx
,
:])
-
distances
(
next_in
[
'prev_pos'
][:,
ca_idx
,
:]))
mask
=
batch
[
'seq_mask'
][:,
None
]
*
batch
[
'seq_mask'
][
None
,
:]
sq_diff
=
utils
.
mask_mean
(
mask
,
sq_diff
)
# Early stopping criteria based on criteria used in
# AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
diff
=
jnp
.
sqrt
(
sq_diff
+
1e-8
)
# avoid bad numerics giving negatives
less_than_max_recycles
=
(
i
<
num_iter
)
has_exceeded_tolerance
=
(
(
i
==
0
)
|
(
diff
>
c
.
recycle_early_stop_tolerance
))
return
less_than_max_recycles
&
has_exceeded_tolerance
if
hk
.
running_init
():
num_recycles
,
_
,
prev
,
safe_key
=
recycle_body
(
(
0
,
prev
,
prev
,
safe_key
))
else
:
num_recycles
,
_
,
prev
,
safe_key
=
hk
.
while_loop
(
recycle_cond
,
recycle_body
,
(
0
,
prev
,
prev
,
safe_key
))
else
:
# No recycling.
num_recycles
=
0
# Run extra iteration.
# Run extra iteration.
ret
=
apply_network
(
prev
=
prev
,
safe_key
=
safe_key
)
ret
=
apply_network
(
prev
=
prev
,
safe_key
=
safe_key
)
if
not
return_representations
:
if
not
return_representations
:
del
ret
[
'representations'
]
del
ret
[
'representations'
]
ret
[
'num_recycles'
]
=
num_recycles
return
ret
return
ret
...
@@ -524,11 +555,13 @@ class EmbeddingsAndEvoformer(hk.Module):
...
@@ -524,11 +555,13 @@ class EmbeddingsAndEvoformer(hk.Module):
Feature embedding using the features as described before.
Feature embedding using the features as described before.
"""
"""
c
=
self
.
config
c
=
self
.
config
gc
=
self
.
global_config
rel_feats
=
[]
rel_feats
=
[]
pos
=
batch
[
'residue_index'
]
pos
=
batch
[
'residue_index'
]
asym_id
=
batch
[
'asym_id'
]
asym_id
=
batch
[
'asym_id'
]
asym_id_same
=
jnp
.
equal
(
asym_id
[:,
None
],
asym_id
[
None
,
:])
asym_id_same
=
jnp
.
equal
(
asym_id
[:,
None
],
asym_id
[
None
,
:])
offset
=
pos
[:,
None
]
-
pos
[
None
,
:]
offset
=
pos
[:,
None
]
-
pos
[
None
,
:]
dtype
=
jnp
.
bfloat16
if
gc
.
bfloat16
else
jnp
.
float32
clipped_offset
=
jnp
.
clip
(
clipped_offset
=
jnp
.
clip
(
offset
+
c
.
max_relative_idx
,
a_min
=
0
,
a_max
=
2
*
c
.
max_relative_idx
)
offset
+
c
.
max_relative_idx
,
a_min
=
0
,
a_max
=
2
*
c
.
max_relative_idx
)
...
@@ -568,6 +601,7 @@ class EmbeddingsAndEvoformer(hk.Module):
...
@@ -568,6 +601,7 @@ class EmbeddingsAndEvoformer(hk.Module):
rel_feat
=
jnp
.
concatenate
(
rel_feats
,
axis
=-
1
)
rel_feat
=
jnp
.
concatenate
(
rel_feats
,
axis
=-
1
)
rel_feat
=
rel_feat
.
astype
(
dtype
)
return
common_modules
.
Linear
(
return
common_modules
.
Linear
(
c
.
pair_channel
,
c
.
pair_channel
,
name
=
'position_activations'
)(
name
=
'position_activations'
)(
...
@@ -579,6 +613,7 @@ class EmbeddingsAndEvoformer(hk.Module):
...
@@ -579,6 +613,7 @@ class EmbeddingsAndEvoformer(hk.Module):
gc
=
self
.
global_config
gc
=
self
.
global_config
batch
=
dict
(
batch
)
batch
=
dict
(
batch
)
dtype
=
jnp
.
bfloat16
if
gc
.
bfloat16
else
jnp
.
float32
if
safe_key
is
None
:
if
safe_key
is
None
:
safe_key
=
prng
.
SafeKey
(
hk
.
next_rng_key
())
safe_key
=
prng
.
SafeKey
(
hk
.
next_rng_key
())
...
@@ -587,7 +622,8 @@ class EmbeddingsAndEvoformer(hk.Module):
...
@@ -587,7 +622,8 @@ class EmbeddingsAndEvoformer(hk.Module):
batch
[
'msa_profile'
]
=
make_msa_profile
(
batch
)
batch
[
'msa_profile'
]
=
make_msa_profile
(
batch
)
target_feat
=
jax
.
nn
.
one_hot
(
batch
[
'aatype'
],
21
)
with
utils
.
bfloat16_context
():
target_feat
=
jax
.
nn
.
one_hot
(
batch
[
'aatype'
],
21
).
astype
(
dtype
)
preprocess_1d
=
common_modules
.
Linear
(
preprocess_1d
=
common_modules
.
Linear
(
c
.
msa_channel
,
name
=
'preprocess_1d'
)(
c
.
msa_channel
,
name
=
'preprocess_1d'
)(
...
@@ -600,12 +636,11 @@ class EmbeddingsAndEvoformer(hk.Module):
...
@@ -600,12 +636,11 @@ class EmbeddingsAndEvoformer(hk.Module):
(
batch
[
'cluster_profile'
],
(
batch
[
'cluster_profile'
],
batch
[
'cluster_deletion_mean'
])
=
nearest_neighbor_clusters
(
batch
)
batch
[
'cluster_deletion_mean'
])
=
nearest_neighbor_clusters
(
batch
)
msa_feat
=
create_msa_feat
(
batch
)
msa_feat
=
create_msa_feat
(
batch
)
.
astype
(
dtype
)
preprocess_msa
=
common_modules
.
Linear
(
preprocess_msa
=
common_modules
.
Linear
(
c
.
msa_channel
,
name
=
'preprocess_msa'
)(
c
.
msa_channel
,
name
=
'preprocess_msa'
)(
msa_feat
)
msa_feat
)
msa_activations
=
jnp
.
expand_dims
(
preprocess_1d
,
axis
=
0
)
+
preprocess_msa
msa_activations
=
jnp
.
expand_dims
(
preprocess_1d
,
axis
=
0
)
+
preprocess_msa
left_single
=
common_modules
.
Linear
(
left_single
=
common_modules
.
Linear
(
...
@@ -616,7 +651,7 @@ class EmbeddingsAndEvoformer(hk.Module):
...
@@ -616,7 +651,7 @@ class EmbeddingsAndEvoformer(hk.Module):
target_feat
)
target_feat
)
pair_activations
=
left_single
[:,
None
]
+
right_single
[
None
]
pair_activations
=
left_single
[:,
None
]
+
right_single
[
None
]
mask_2d
=
batch
[
'seq_mask'
][:,
None
]
*
batch
[
'seq_mask'
][
None
,
:]
mask_2d
=
batch
[
'seq_mask'
][:,
None
]
*
batch
[
'seq_mask'
][
None
,
:]
mask_2d
=
mask_2d
.
astype
(
jnp
.
float32
)
mask_2d
=
mask_2d
.
astype
(
dtype
)
if
c
.
recycle_pos
:
if
c
.
recycle_pos
:
prev_pseudo_beta
=
modules
.
pseudo_beta_fn
(
prev_pseudo_beta
=
modules
.
pseudo_beta_fn
(
...
@@ -624,25 +659,25 @@ class EmbeddingsAndEvoformer(hk.Module):
...
@@ -624,25 +659,25 @@ class EmbeddingsAndEvoformer(hk.Module):
dgram
=
modules
.
dgram_from_positions
(
dgram
=
modules
.
dgram_from_positions
(
prev_pseudo_beta
,
**
self
.
config
.
prev_pos
)
prev_pseudo_beta
,
**
self
.
config
.
prev_pos
)
dgram
=
dgram
.
astype
(
dtype
)
pair_activations
+=
common_modules
.
Linear
(
pair_activations
+=
common_modules
.
Linear
(
c
.
pair_channel
,
name
=
'prev_pos_linear'
)(
c
.
pair_channel
,
name
=
'prev_pos_linear'
)(
dgram
)
dgram
)
if
c
.
recycle_features
:
if
c
.
recycle_features
:
prev_msa_first_row
=
hk
.
LayerNorm
(
prev_msa_first_row
=
common_modules
.
LayerNorm
(
axis
=
[
-
1
],
axis
=
[
-
1
],
create_scale
=
True
,
create_scale
=
True
,
create_offset
=
True
,
create_offset
=
True
,
name
=
'prev_msa_first_row_norm'
)(
name
=
'prev_msa_first_row_norm'
)(
batch
[
'prev_msa_first_row'
])
batch
[
'prev_msa_first_row'
])
.
astype
(
dtype
)
msa_activations
=
msa_activations
.
at
[
0
].
add
(
prev_msa_first_row
)
msa_activations
=
msa_activations
.
at
[
0
].
add
(
prev_msa_first_row
)
pair_activations
+=
hk
.
LayerNorm
(
pair_activations
+=
common_modules
.
LayerNorm
(
axis
=
[
-
1
],
axis
=
[
-
1
],
create_scale
=
True
,
create_scale
=
True
,
create_offset
=
True
,
create_offset
=
True
,
name
=
'prev_pair_norm'
)(
name
=
'prev_pair_norm'
)(
batch
[
'prev_pair'
])
batch
[
'prev_pair'
])
.
astype
(
dtype
)
if
c
.
max_relative_idx
:
if
c
.
max_relative_idx
:
pair_activations
+=
self
.
_relative_encoding
(
batch
)
pair_activations
+=
self
.
_relative_encoding
(
batch
)
...
@@ -673,8 +708,8 @@ class EmbeddingsAndEvoformer(hk.Module):
...
@@ -673,8 +708,8 @@ class EmbeddingsAndEvoformer(hk.Module):
extra_msa_activations
=
common_modules
.
Linear
(
extra_msa_activations
=
common_modules
.
Linear
(
c
.
extra_msa_channel
,
c
.
extra_msa_channel
,
name
=
'extra_msa_activations'
)(
name
=
'extra_msa_activations'
)(
extra_msa_feat
)
extra_msa_feat
)
.
astype
(
dtype
)
extra_msa_mask
=
extra_msa_mask
.
astype
(
jnp
.
float32
)
extra_msa_mask
=
extra_msa_mask
.
astype
(
dtype
)
extra_evoformer_input
=
{
extra_evoformer_input
=
{
'msa'
:
extra_msa_activations
,
'msa'
:
extra_msa_activations
,
...
@@ -714,18 +749,19 @@ class EmbeddingsAndEvoformer(hk.Module):
...
@@ -714,18 +749,19 @@ class EmbeddingsAndEvoformer(hk.Module):
'msa'
:
msa_activations
,
'msa'
:
msa_activations
,
'pair'
:
pair_activations
,
'pair'
:
pair_activations
,
}
}
evoformer_masks
=
{
'msa'
:
batch
[
'msa_mask'
].
astype
(
jnp
.
float32
),
evoformer_masks
=
{
'pair'
:
mask_2d
}
'msa'
:
batch
[
'msa_mask'
].
astype
(
dtype
),
'pair'
:
mask_2d
}
if
c
.
template
.
enabled
:
if
c
.
template
.
enabled
:
template_features
,
template_masks
=
(
template_features
,
template_masks
=
(
template_embedding_1d
(
batch
=
batch
,
num_channel
=
c
.
msa_channel
))
template_embedding_1d
(
batch
=
batch
,
num_channel
=
c
.
msa_channel
,
global_config
=
gc
))
evoformer_input
[
'msa'
]
=
jnp
.
concatenate
(
evoformer_input
[
'msa'
]
=
jnp
.
concatenate
(
[
evoformer_input
[
'msa'
],
template_features
],
axis
=
0
)
[
evoformer_input
[
'msa'
],
template_features
],
axis
=
0
)
evoformer_masks
[
'msa'
]
=
jnp
.
concatenate
(
evoformer_masks
[
'msa'
]
=
jnp
.
concatenate
(
[
evoformer_masks
[
'msa'
],
template_masks
],
axis
=
0
)
[
evoformer_masks
[
'msa'
],
template_masks
],
axis
=
0
)
evoformer_iteration
=
modules
.
EvoformerIteration
(
evoformer_iteration
=
modules
.
EvoformerIteration
(
c
.
evoformer
,
gc
,
is_extra_msa
=
False
,
name
=
'evoformer_iteration'
)
c
.
evoformer
,
gc
,
is_extra_msa
=
False
,
name
=
'evoformer_iteration'
)
...
@@ -771,6 +807,12 @@ class EmbeddingsAndEvoformer(hk.Module):
...
@@ -771,6 +807,12 @@ class EmbeddingsAndEvoformer(hk.Module):
msa_activations
[
0
],
msa_activations
[
0
],
})
})
# Convert back to float32 if we're not saving memory.
if
not
gc
.
bfloat16_output
:
for
k
,
v
in
output
.
items
():
if
v
.
dtype
==
jnp
.
bfloat16
:
output
[
k
]
=
v
.
astype
(
jnp
.
float32
)
return
output
return
output
...
@@ -917,6 +959,9 @@ class SingleTemplateEmbedding(hk.Module):
...
@@ -917,6 +959,9 @@ class SingleTemplateEmbedding(hk.Module):
# backbone affine - i.e. in each residues local frame, what direction are
# backbone affine - i.e. in each residues local frame, what direction are
# each of the other residues.
# each of the other residues.
raw_atom_pos
=
template_all_atom_positions
raw_atom_pos
=
template_all_atom_positions
if
gc
.
bfloat16
:
# Vec3Arrays are required to be float32
raw_atom_pos
=
raw_atom_pos
.
astype
(
jnp
.
float32
)
atom_pos
=
geometry
.
Vec3Array
.
from_array
(
raw_atom_pos
)
atom_pos
=
geometry
.
Vec3Array
.
from_array
(
raw_atom_pos
)
rigid
,
backbone_mask
=
folding_multimer
.
make_backbone_affine
(
rigid
,
backbone_mask
=
folding_multimer
.
make_backbone_affine
(
...
@@ -928,6 +973,10 @@ class SingleTemplateEmbedding(hk.Module):
...
@@ -928,6 +973,10 @@ class SingleTemplateEmbedding(hk.Module):
unit_vector
=
rigid_vec
.
normalized
()
unit_vector
=
rigid_vec
.
normalized
()
unit_vector
=
[
unit_vector
.
x
,
unit_vector
.
y
,
unit_vector
.
z
]
unit_vector
=
[
unit_vector
.
x
,
unit_vector
.
y
,
unit_vector
.
z
]
if
gc
.
bfloat16
:
unit_vector
=
[
x
.
astype
(
jnp
.
bfloat16
)
for
x
in
unit_vector
]
backbone_mask
=
backbone_mask
.
astype
(
jnp
.
bfloat16
)
backbone_mask_2d
=
backbone_mask
[:,
None
]
*
backbone_mask
[
None
,
:]
backbone_mask_2d
=
backbone_mask
[:,
None
]
*
backbone_mask
[
None
,
:]
backbone_mask_2d
*=
multichain_mask_2d
backbone_mask_2d
*=
multichain_mask_2d
unit_vector
=
[
x
*
backbone_mask_2d
for
x
in
unit_vector
]
unit_vector
=
[
x
*
backbone_mask_2d
for
x
in
unit_vector
]
...
@@ -937,7 +986,7 @@ class SingleTemplateEmbedding(hk.Module):
...
@@ -937,7 +986,7 @@ class SingleTemplateEmbedding(hk.Module):
to_concat
.
extend
([(
x
,
0
)
for
x
in
unit_vector
])
to_concat
.
extend
([(
x
,
0
)
for
x
in
unit_vector
])
to_concat
.
append
((
backbone_mask_2d
,
0
))
to_concat
.
append
((
backbone_mask_2d
,
0
))
query_embedding
=
hk
.
LayerNorm
(
query_embedding
=
common_modules
.
LayerNorm
(
axis
=
[
-
1
],
axis
=
[
-
1
],
create_scale
=
True
,
create_scale
=
True
,
create_offset
=
True
,
create_offset
=
True
,
...
@@ -986,12 +1035,13 @@ class SingleTemplateEmbedding(hk.Module):
...
@@ -986,12 +1035,13 @@ class SingleTemplateEmbedding(hk.Module):
template_iteration_fn
)
template_iteration_fn
)
act
,
safe_key
=
template_stack
((
act
,
safe_subkey
))
act
,
safe_key
=
template_stack
((
act
,
safe_subkey
))
act
=
hk
.
LayerNorm
(
act
=
common_modules
.
LayerNorm
(
axis
=
[
-
1
],
axis
=
[
-
1
],
create_scale
=
True
,
create_scale
=
True
,
create_offset
=
True
,
create_offset
=
True
,
name
=
'output_layer_norm'
)(
name
=
'output_layer_norm'
)(
act
)
act
)
return
act
return
act
...
@@ -1044,21 +1094,18 @@ class TemplateEmbeddingIteration(hk.Module):
...
@@ -1044,21 +1094,18 @@ class TemplateEmbeddingIteration(hk.Module):
act
,
act
,
pair_mask
,
pair_mask
,
safe_key
=
next
(
sub_keys
))
safe_key
=
next
(
sub_keys
))
act
=
dropout_wrapper_fn
(
act
=
dropout_wrapper_fn
(
modules
.
TriangleAttention
(
c
.
triangle_attention_starting_node
,
gc
,
modules
.
TriangleAttention
(
c
.
triangle_attention_starting_node
,
gc
,
name
=
'triangle_attention_starting_node'
),
name
=
'triangle_attention_starting_node'
),
act
,
act
,
pair_mask
,
pair_mask
,
safe_key
=
next
(
sub_keys
))
safe_key
=
next
(
sub_keys
))
act
=
dropout_wrapper_fn
(
act
=
dropout_wrapper_fn
(
modules
.
TriangleAttention
(
c
.
triangle_attention_ending_node
,
gc
,
modules
.
TriangleAttention
(
c
.
triangle_attention_ending_node
,
gc
,
name
=
'triangle_attention_ending_node'
),
name
=
'triangle_attention_ending_node'
),
act
,
act
,
pair_mask
,
pair_mask
,
safe_key
=
next
(
sub_keys
))
safe_key
=
next
(
sub_keys
))
act
=
dropout_wrapper_fn
(
act
=
dropout_wrapper_fn
(
modules
.
Transition
(
c
.
pair_transition
,
gc
,
modules
.
Transition
(
c
.
pair_transition
,
gc
,
name
=
'pair_transition'
),
name
=
'pair_transition'
),
@@ -1069,7 +1116,7 @@ class TemplateEmbeddingIteration(hk.Module):
     return act


-def template_embedding_1d(batch, num_channel):
+def template_embedding_1d(batch, num_channel, global_config):
   """Embed templates into an (num_res, num_templates, num_channels) embedding.

   Args:
@@ -1080,6 +1127,7 @@ def template_embedding_1d(batch, num_channel):
     template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
       each template.
     num_channel: The number of channels in the output.
+    global_config: The global_config.

   Returns:
     An embedding of shape (num_templates, num_res, num_channels) and a mask of
@@ -1112,6 +1160,10 @@ def template_embedding_1d(batch, num_channel):
   template_mask = chi_mask[:, :, 0]

+  if global_config.bfloat16:
+    template_features = template_features.astype(jnp.bfloat16)
+    template_mask = template_mask.astype(jnp.bfloat16)
+
   template_activations = common_modules.Linear(
       num_channel,
       initializer='relu',
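Note: template_embedding_1d now takes global_config and casts the template features and mask to bfloat16 when global_config.bfloat16 is set, presumably to halve activation memory during inference. A minimal sketch of the gating pattern (helper name hypothetical):

import jax.numpy as jnp


def maybe_cast_to_bfloat16(features, mask, use_bfloat16):
  """Hypothetical helper mirroring the gate added above."""
  if use_bfloat16:
    features = features.astype(jnp.bfloat16)
    mask = mask.astype(jnp.bfloat16)
  return features, mask


feats, mask = maybe_cast_to_bfloat16(
    jnp.ones((4, 128, 57)), jnp.ones((4, 128)), use_bfloat16=True)
print(feats.dtype)  # bfloat16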
alphafold/model/utils.py View file @ 9b18d6a9
@@ -15,6 +15,7 @@
 """A collection of JAX utility functions for use in protein folding."""

 import collections
+import contextlib
 import functools
 import numbers
 from typing import Mapping
@@ -25,6 +26,27 @@ import jax.numpy as jnp
 import numpy as np


+def bfloat16_creator(next_creator, shape, dtype, init, context):
+  """Creates float32 variables when bfloat16 is requested."""
+  if context.original_dtype == jnp.bfloat16:
+    dtype = jnp.float32
+  return next_creator(shape, dtype, init)
+
+
+def bfloat16_getter(next_getter, value, context):
+  """Casts float32 to bfloat16 when bfloat16 was originally requested."""
+  if context.original_dtype == jnp.bfloat16:
+    assert value.dtype == jnp.float32
+    value = value.astype(jnp.bfloat16)
+  return next_getter(value)
+
+
+@contextlib.contextmanager
+def bfloat16_context():
+  with hk.custom_creator(bfloat16_creator), hk.custom_getter(bfloat16_getter):
+    yield
+
+
 def final_init(config):
   if config.zero_init:
     return 'zeros'
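Note: the three additions above are Haiku parameter interceptors. bfloat16_creator stores master parameters in float32 even when a module asks for bfloat16, and bfloat16_getter casts them back to bfloat16 at use time, so compute runs in half precision while the checkpoint stays float32. A hedged usage sketch (not taken from the diff; the wrapped module is arbitrary):

import haiku as hk
import jax
import jax.numpy as jnp
from alphafold.model import utils


def forward(x):
  with utils.bfloat16_context():
    # The Linear requests bfloat16 params because its input is bfloat16;
    # the creator stores float32, the getter casts back for the matmul.
    return hk.Linear(8)(x.astype(jnp.bfloat16))


fn = hk.transform(forward)
params = fn.init(jax.random.PRNGKey(0), jnp.ones((4, 16)))
print(jax.tree_util.tree_map(lambda p: p.dtype, params))  # float32 leaves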
alphafold/notebooks/notebook_utils.py View file @ 9b18d6a9
@@ -13,7 +13,6 @@
 # limitations under the License.
 """Helper methods for the AlphaFold Colab notebook."""

-import enum
 import json
 from typing import Any, Mapping, Optional, Sequence, Tuple
@@ -23,13 +22,7 @@ from matplotlib import pyplot as plt
 import numpy as np


-@enum.unique
-class ModelType(enum.Enum):
-  MONOMER = 0
-  MULTIMER = 1
-
-
-def clean_and_validate_sequence(
+def clean_and_validate_single_sequence(
     input_sequence: str, min_length: int, max_length: int) -> str:
   """Checks that the input sequence is ok and returns a clean version of it."""
   # Remove all whitespaces, tabs and end lines; upper-case.
@@ -54,41 +47,23 @@ def clean_and_validate_sequence(
   return clean_sequence


-def validate_input(
+def clean_and_validate_input_sequences(
     input_sequences: Sequence[str],
-    min_length: int,
-    max_length: int,
-    max_multimer_length: int) -> Tuple[Sequence[str], ModelType]:
-  """Validates and cleans input sequences and determines which model to use."""
+    min_sequence_length: int,
+    max_sequence_length: int) -> Sequence[str]:
+  """Validates and cleans input sequences."""
   sequences = []

   for input_sequence in input_sequences:
     if input_sequence.strip():
-      input_sequence = clean_and_validate_sequence(
+      input_sequence = clean_and_validate_single_sequence(
           input_sequence=input_sequence,
-          min_length=min_length,
-          max_length=max_length)
+          min_length=min_sequence_length,
+          max_length=max_sequence_length)
       sequences.append(input_sequence)

-  if len(sequences) == 1:
-    print('Using the single-chain model.')
-    return sequences, ModelType.MONOMER
-  elif len(sequences) > 1:
-    total_multimer_length = sum([len(seq) for seq in sequences])
-    if total_multimer_length > max_multimer_length:
-      raise ValueError(f'The total length of multimer sequences is too long: '
-                       f'{total_multimer_length}, while the maximum is '
-                       f'{max_multimer_length}. Please use the full AlphaFold '
-                       f'system for long multimers.')
-    elif total_multimer_length > 1536:
-      print('WARNING: The accuracy of the system has not been fully validated '
-            'above 1536 residues, and you may experience long running times or '
-            f'run out of memory for your complex with {total_multimer_length} '
-            'residues.')
-    print(f'Using the multimer model with {len(sequences)} sequences.')
-    return sequences, ModelType.MULTIMER
+  if sequences:
+    return sequences
   else:
     raise ValueError('No input amino acid sequence provided, please provide at '
                      'least one sequence.')
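Note: sequence validation and model choice are now decoupled; ModelType and the total-length checks are gone from this helper. A hedged usage sketch of the renamed function (input values illustrative):

from alphafold.notebooks import notebook_utils

sequences = notebook_utils.clean_and_validate_input_sequences(
    input_sequences=['MKTAYIAKQR ', '', 'mdvfmkglsk'],
    min_sequence_length=10,
    max_sequence_length=2500)
print(sequences)  # ['MKTAYIAKQR', 'MDVFMKGLSK'] - blanks dropped, upper-cased

Deciding between the monomer and multimer model from len(sequences) is now the caller's job.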
alphafold/notebooks/notebook_utils_test.py View file @ 9b18d6a9
@@ -85,7 +85,7 @@ class NotebookUtilsTest(parameterized.TestCase):
       ('DeepMind', 'DEEPMIND'), ('A ', 'A'), ('\tA', 'A'), (' A \t\n', 'A'),
       ('ACDEFGHIKLMNPQRSTVWY', 'ACDEFGHIKLMNPQRSTVWY'))
   def test_clean_and_validate_sequence_ok(self, sequence, exp_clean):
-    clean = notebook_utils.clean_and_validate_sequence(
+    clean = notebook_utils.clean_and_validate_single_sequence(
         sequence, min_length=1, max_length=100)
     self.assertEqual(clean, exp_clean)
@@ -100,35 +100,29 @@ class NotebookUtilsTest(parameterized.TestCase):
       ('bad_amino_acids_Z', 'ZZZZ', 'non-amino acid'))
   def test_clean_and_validate_sequence_bad(self, sequence, exp_error):
     with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
-      notebook_utils.clean_and_validate_sequence(
+      notebook_utils.clean_and_validate_single_sequence(
           sequence, min_length=4, max_length=8)

   @parameterized.parameters(
-      (['A', '', '', ' ', '\t', '\t\n', '', ''], ['A'],
-       notebook_utils.ModelType.MONOMER),
-      (['', 'A'], ['A'], notebook_utils.ModelType.MONOMER),
-      (['A', 'C ', ''], ['A', 'C'], notebook_utils.ModelType.MULTIMER),
-      (['', 'A', '', 'C '], ['A', 'C'], notebook_utils.ModelType.MULTIMER))
-  def test_validate_input_ok(self, input_sequences, exp_sequences,
-                             exp_model_type):
-    sequences, model_type = notebook_utils.validate_input(
-        input_sequences=input_sequences,
-        min_length=1, max_length=100, max_multimer_length=100)
-    self.assertSequenceEqual(sequences, exp_sequences)
-    self.assertEqual(model_type, exp_model_type)
+      (['A', '', '', ' ', '\t', '\t\n', '', ''], ['A']),
+      (['', 'A'], ['A']),
+      (['A', 'C ', ''], ['A', 'C']),
+      (['', 'A', '', 'C '], ['A', 'C']))
+  def test_validate_input_ok(self, input_sequences, exp_sequences):
+    sequences = notebook_utils.clean_and_validate_input_sequences(
+        input_sequences=input_sequences,
+        min_sequence_length=1, max_sequence_length=100)
+    self.assertSequenceEqual(sequences, exp_sequences)

   @parameterized.named_parameters(
       ('no_input_sequence', ['', '\t', '\n'], 'No input amino acid sequence'),
       ('too_long_single', ['AAAAAAAAA', 'AAAA'], 'Input sequence is too long'),
-      ('too_long_multimer', ['AAAA', 'AAAAA'], 'The total length of multimer'))
+      ('too_short_single', ['AAA', 'AAAA'], 'Input sequence is too short'))
   def test_validate_input_bad(self, input_sequences, exp_error):
     with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
-      notebook_utils.validate_input(
+      notebook_utils.clean_and_validate_input_sequences(
           input_sequences=input_sequences,
-          min_length=4, max_length=8, max_multimer_length=6)
+          min_sequence_length=4, max_sequence_length=8)

   def test_merge_chunked_msa_no_hits(self):
     results = [ONLY_QUERY_HIT, ONLY_QUERY_HIT]
alphafold/relax/relax.py View file @ 9b18d6a9
@@ -56,7 +56,8 @@ class AmberRelaxation(object):
     self._use_gpu = use_gpu

   def process(self, *,
-              prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
+              prot: protein.Protein) -> Tuple[str, Dict[str, Any],
+                                              Sequence[float]]:
     """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
     out = amber_minimize.run_pipeline(
         prot=prot, max_iterations=self._max_iterations,
@@ -73,12 +74,11 @@ class AmberRelaxation(object):
         'attempts': out['min_attempts'],
         'rmsd': rmsd
     }
-    pdb_str = amber_minimize.clean_protein(prot)
-    min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
+    min_pdb = out['min_pdb']
     min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
     utils.assert_equal_nonterminal_atom_types(
         protein.from_pdb_string(min_pdb).atom_mask,
         prot.atom_mask)
     violations = out['structural_violations'][
-        'total_per_residue_violations_mask']
+        'total_per_residue_violations_mask'].tolist()
     return min_pdb, debug_data, violations
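Note: process() now returns the per-residue violation mask as a plain Python list (Sequence[float]) instead of an np.ndarray, and takes the minimised PDB directly from the pipeline output rather than rebuilding it from coordinates. One likely motivation for .tolist(), stated here as an assumption, is that plain lists serialise cleanly:

import json
import numpy as np

violations = np.array([0.0, 1.0, 0.0]).tolist()  # as relax.py now does
json.dumps({'violations': violations})  # fine; an ndarray would raise TypeError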
alphafold/relax/relax_test.py View file @ 9b18d6a9
@@ -82,7 +82,7 @@ class RunAmberRelaxTest(absltest.TestCase):
         0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
         0, 0, 0, 0, 0])
     # Check no violations were added. Can't check exactly due to stochasticity.
-    self.assertTrue(np.all(num_violations <= exp_num_violations))
+    self.assertTrue(np.all(np.array(num_violations) <= exp_num_violations))

 if __name__ == '__main__':
alphafold/relax/utils.py View file @ 9b18d6a9
@@ -17,17 +17,6 @@ import io
 from alphafold.common import residue_constants
 from Bio import PDB
 import numpy as np
-from simtk.openmm import app as openmm_app
-from simtk.openmm.app.internal.pdbstructure import PdbStructure
-
-
-def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
-  pdb_file = io.StringIO(pdb_str)
-  structure = PdbStructure(pdb_file)
-  topology = openmm_app.PDBFile(structure).getTopology()
-  with io.StringIO() as f:
-    openmm_app.PDBFile.writeFile(topology, pos, f)
-    return f.getvalue()
-
-
 def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
docker/Dockerfile View file @ 9b18d6a9
@@ -59,7 +59,7 @@ RUN conda install -qy conda==4.13.0 \
       cudatoolkit==${CUDA_VERSION} \
       pdbfixer \
       pip \
-      python=3.7 \
+      python=3.8 \
     && conda clean --all --force-pkgs-dirs --yes

 COPY . /app/alphafold
@@ -75,7 +75,7 @@ RUN pip3 install --upgrade pip --no-cache-dir \
     -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

 # Apply OpenMM patch.
-WORKDIR /opt/conda/lib/python3.7/site-packages
+WORKDIR /opt/conda/lib/python3.8/site-packages
 RUN patch -p0 < /app/alphafold/docker/openmm.patch

 # Add SETUID bit to the ldconfig binary so that non-root users can run it.
docker/run_docker.py View file @ 9b18d6a9
@@ -133,7 +133,7 @@ def main(argv):
   # Path to the MGnify database for use by JackHMMER.
   mgnify_database_path = os.path.join(
-      FLAGS.data_dir, 'mgnify', 'mgy_clusters_2018_12.fa')
+      FLAGS.data_dir, 'mgnify', 'mgy_clusters_2022_05.fa')

   # Path to the BFD database for use by HHblits.
   bfd_database_path = os.path.join(
@@ -144,9 +144,9 @@ def main(argv):
   small_bfd_database_path = os.path.join(
       FLAGS.data_dir, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta')

-  # Path to the Uniclust30 database for use by HHblits.
-  uniclust30_database_path = os.path.join(
-      FLAGS.data_dir, 'uniclust30', 'uniclust30_2018_08', 'uniclust30_2018_08')
+  # Path to the UniRef30 database for use by HHblits.
+  uniref30_database_path = os.path.join(
+      FLAGS.data_dir, 'uniref30', 'UniRef30_2021_03')

   # Path to the PDB70 database for use by HHsearch.
   pdb70_database_path = os.path.join(FLAGS.data_dir, 'pdb70', 'pdb70')
@@ -199,7 +199,7 @@ def main(argv):
     database_paths.append(('small_bfd_database_path', small_bfd_database_path))
   else:
     database_paths.extend([
-        ('uniclust30_database_path', uniclust30_database_path),
+        ('uniref30_database_path', uniref30_database_path),
         ('bfd_database_path', bfd_database_path),
     ])
   for name, path in database_paths:
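Note: the MGnify update and the Uniclust30 → UniRef30 rename change the on-disk layout users must download. A sketch under assumptions about the surrounding run_docker.py code: each (name, path) pair is presumably forwarded to run_alphafold.py as a --name=path flag, so callers now pass --uniref30_database_path (paths illustrative):

database_paths = [
    ('uniref30_database_path', '/data/uniref30/UniRef30_2021_03'),
    ('bfd_database_path', '/data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt'),
]
command_args = [f'--{name}={path}' for name, path in database_paths]
print(command_args)  # ['--uniref30_database_path=...', '--bfd_database_path=...']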
docs/casp15_predictions.zip 0 → 100644 View file @ 9b18d6a9
File added