Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
bb3f51e5
Unverified
Commit
bb3f51e5
authored
Feb 07, 2024
by
Christina Floristean
Committed by
GitHub
Feb 07, 2024
Browse files
Merge pull request #405 from aqlaboratory/multimer
Full multimer merge
parents
ce211367
c33a0bd6
Changes
106
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3168 additions
and
670 deletions
+3168
-670
openfold/data/tools/hhblits.py
openfold/data/tools/hhblits.py
+4
-4
openfold/data/tools/hhsearch.py
openfold/data/tools/hhsearch.py
+23
-4
openfold/data/tools/hmmbuild.py
openfold/data/tools/hmmbuild.py
+137
-0
openfold/data/tools/hmmsearch.py
openfold/data/tools/hmmsearch.py
+137
-0
openfold/data/tools/jackhmmer.py
openfold/data/tools/jackhmmer.py
+42
-12
openfold/data/tools/kalign.py
openfold/data/tools/kalign.py
+1
-1
openfold/data/tools/parse_msa_files.py
openfold/data/tools/parse_msa_files.py
+57
-0
openfold/model/embedders.py
openfold/model/embedders.py
+532
-9
openfold/model/evoformer.py
openfold/model/evoformer.py
+387
-214
openfold/model/heads.py
openfold/model/heads.py
+9
-1
openfold/model/model.py
openfold/model/model.py
+188
-147
openfold/model/primitives.py
openfold/model/primitives.py
+23
-0
openfold/model/structure_module.py
openfold/model/structure_module.py
+501
-69
openfold/model/template.py
openfold/model/template.py
+185
-130
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+195
-8
openfold/np/protein.py
openfold/np/protein.py
+72
-8
openfold/np/relax/amber_minimize.py
openfold/np/relax/amber_minimize.py
+0
-57
openfold/np/relax/utils.py
openfold/np/relax/utils.py
+1
-1
openfold/np/residue_constants.py
openfold/np/residue_constants.py
+180
-5
openfold/utils/all_atom_multimer.py
openfold/utils/all_atom_multimer.py
+494
-0
No files found.
openfold/data/tools/hhblits.py
View file @
bb3f51e5
...
...
@@ -18,7 +18,7 @@ import glob
import
logging
import
os
import
subprocess
from
typing
import
Any
,
Mapping
,
Optional
,
Sequence
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Sequence
from
openfold.data.tools
import
utils
...
...
@@ -99,9 +99,9 @@ class HHBlits:
self
.
p
=
p
self
.
z
=
z
def
query
(
self
,
input_fasta_path
:
str
)
->
Mapping
[
str
,
Any
]:
def
query
(
self
,
input_fasta_path
:
str
)
->
List
[
Mapping
[
str
,
Any
]
]
:
"""Queries the database using HHblits."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.a3m"
)
db_cmd
=
[]
...
...
@@ -172,4 +172,4 @@ class HHBlits:
n_iter
=
self
.
n_iter
,
e_value
=
self
.
e_value
,
)
return
raw_output
return
[
raw_output
]
openfold/data/tools/hhsearch.py
View file @
bb3f51e5
...
...
@@ -18,8 +18,9 @@ import glob
import
logging
import
os
import
subprocess
from
typing
import
Sequence
from
typing
import
Sequence
,
Optional
from
openfold.data
import
parsers
from
openfold.data.tools
import
utils
...
...
@@ -62,11 +63,20 @@ class HHSearch:
f
"Could not find HHsearch database
{
database_path
}
"
)
def
query
(
self
,
a3m
:
str
)
->
str
:
@
property
def
output_format
(
self
)
->
str
:
return
'hhr'
@
property
def
input_format
(
self
)
->
str
:
return
'a3m'
def
query
(
self
,
a3m
:
str
,
output_dir
:
Optional
[
str
]
=
None
)
->
str
:
"""Queries the database using HHsearch using a given a3m."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
input_path
=
os
.
path
.
join
(
query_tmp_dir
,
"query.a3m"
)
hhr_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.hhr"
)
output_dir
=
query_tmp_dir
if
output_dir
is
None
else
output_dir
hhr_path
=
os
.
path
.
join
(
output_dir
,
"hhsearch_output.hhr"
)
with
open
(
input_path
,
"w"
)
as
f
:
f
.
write
(
a3m
)
...
...
@@ -104,3 +114,12 @@ class HHSearch:
with
open
(
hhr_path
)
as
f
:
hhr
=
f
.
read
()
return
hhr
@
staticmethod
def
get_template_hits
(
output_string
:
str
,
input_sequence
:
str
)
->
Sequence
[
parsers
.
TemplateHit
]:
"""Gets parsed template hits from the raw string output by the tool"""
del
input_sequence
# Used by hmmsearch but not needed for hhsearch
return
parsers
.
parse_hhr
(
output_string
)
openfold/data/tools/hmmbuild.py
0 → 100644
View file @
bb3f51e5
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
import
os
import
re
import
subprocess
from
absl
import
logging
from
openfold.data.tools
import
utils
class Hmmbuild(object):
    """Python wrapper of the hmmbuild binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 singlemx: bool = False):
        """Initializes the Python hmmbuild wrapper.

        The arguments are only stored; the binary path is not validated here,
        so a wrong path surfaces when a profile is built.

        Args:
            binary_path: The path to the hmmbuild executable.
            singlemx: Whether to use --singlemx flag. If True, it forces
                HMMBuild to just use a common substitution score matrix.
        """
        self.binary_path = binary_path
        self.singlemx = singlemx

    def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
        """Builds an HMM for the aligned sequences given as a Stockholm string.

        Args:
            sto: A string with the aligned sequences in the Stockholm format.
            model_construction: Whether to use reference annotation in the msa
                to determine consensus columns ('hand') or default ('fast').

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
            ValueError: If model_construction is neither 'hand' nor 'fast'.
        """
        return self._build_profile(sto, model_construction=model_construction)

    def build_profile_from_a3m(self, a3m: str) -> str:
        """Builds an HMM for the aligned sequences given as an A3M string.

        Lowercase letters are stripped from every non-header line before the
        alignment is handed to hmmbuild.

        Args:
            a3m: A string with the aligned sequences in the A3M format.

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
        """
        lines = []
        for line in a3m.splitlines():
            if not line.startswith('>'):
                line = re.sub('[a-z]+', '', line)  # Remove inserted residues.
            lines.append(line + '\n')
        msa = ''.join(lines)
        return self._build_profile(msa, model_construction='fast')

    def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
        """Builds an HMM for the aligned sequences given as an MSA string.

        Args:
            msa: A string with the aligned sequences, in A3M or STO format.
            model_construction: Whether to use reference annotation in the msa
                to determine consensus columns ('hand') or default ('fast').

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
            ValueError: If unspecified arguments are provided.
        """
        if model_construction not in {'hand', 'fast'}:
            # The two literals used to be concatenated without a separating
            # space, yielding "... - onlyhand and fast supported.".
            raise ValueError(
                f'Invalid model_construction {model_construction} - only '
                'hand and fast supported.'
            )

        with utils.tmpdir_manager() as query_tmp_dir:
            input_query = os.path.join(query_tmp_dir, 'query.msa')
            output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')

            with open(input_query, 'w') as f:
                f.write(msa)

            cmd = [self.binary_path]
            # If adding flags, we have to do so before the output and input:

            if model_construction == 'hand':
                cmd.append(f'--{model_construction}')
            if self.singlemx:
                cmd.append('--singlemx')
            cmd.extend([
                '--amino',
                output_hmm_path,
                input_query,
            ])

            logging.info('Launching subprocess %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing('hmmbuild query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
                             stdout.decode('utf-8'), stderr.decode('utf-8'))

            if retcode:
                raise RuntimeError(
                    'hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
                    % (stdout.decode('utf-8'), stderr.decode('utf-8'))
                )

            # Read the profile before the temporary directory is cleaned up.
            with open(output_hmm_path, encoding='utf-8') as f:
                hmm = f.read()

        return hmm
openfold/data/tools/hmmsearch.py
0 → 100644
View file @
bb3f51e5
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import
os
import
subprocess
from
typing
import
Optional
,
Sequence
from
absl
import
logging
from
openfold.data
import
parsers
from
openfold.data.tools
import
hmmbuild
from
openfold.data.tools
import
utils
class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 hmmbuild_binary_path: str,
                 database_path: str,
                 flags: Optional[Sequence[str]] = None):
        """Initializes the Python hmmsearch wrapper.

        Args:
            binary_path: The path to the hmmsearch executable.
            hmmbuild_binary_path: The path to the hmmbuild executable. Used to
                build an hmm from an input a3m.
            database_path: The path to the hmmsearch database (FASTA format).
            flags: List of flags to be used by hmmsearch. When None, a default
                set of filter/E-value flags is used.

        Raises:
            ValueError: If database_path does not exist.
        """
        self.binary_path = binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
        self.database_path = database_path
        if flags is None:
            # Default hmmsearch run settings.
            flags = ['--F1', '0.1',
                     '--F2', '0.1',
                     '--F3', '0.1',
                     '--incE', '100',
                     '-E', '100',
                     '--domE', '100',
                     '--incdomE', '100']
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error('Could not find hmmsearch database %s', database_path)
            raise ValueError(f'Could not find hmmsearch database {database_path}')

    @property
    def output_format(self) -> str:
        """Format tag of the string returned by `query` ('sto')."""
        return 'sto'

    @property
    def input_format(self) -> str:
        """Format tag of the MSA string expected by `query` ('sto')."""
        return 'sto'

    def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given stockholm msa.

        The profile is built with model_construction='hand' and then passed to
        `query_with_hmm`.
        """
        hmm = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto, model_construction='hand'
        )
        return self.query_with_hmm(hmm, output_dir)

    def query_with_hmm(self,
                       hmm: str,
                       output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given hmm.

        Args:
            hmm: The profile, as text, written to a temporary query file.
            output_dir: Directory that receives 'hmm_output.sto'. When None the
                temporary working directory is used instead.

        Returns:
            The contents of the alignment file written via hmmsearch's -A flag.

        Raises:
            RuntimeError: If hmmsearch exits with a non-zero return code.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
            output_dir = query_tmp_dir if output_dir is None else output_dir
            out_path = os.path.join(output_dir, 'hmm_output.sto')
            # Explicit UTF-8 keeps file I/O consistent with how the process
            # output is decoded below, independent of the locale.
            with open(hmm_input_path, 'w', encoding='utf-8') as f:
                f.write(hmm)

            cmd = [
                self.binary_path,
                '--noali',  # Don't include the alignment in stdout.
                '--cpu', '8'
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend([
                '-A', out_path,
                hmm_input_path,
                self.database_path,
            ])

            logging.info('Launching sub-process %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                    f'hmmsearch ({os.path.basename(self.database_path)}) query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
                        stdout.decode('utf-8'), stderr.decode('utf-8')))

            # Must be read inside the context manager: out_path may live in
            # the temporary directory.
            with open(out_path, encoding='utf-8') as f:
                out_msa = f.read()

        return out_msa

    @staticmethod
    def get_template_hits(
        output_string: str,
        input_sequence: str
    ) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        template_hits = parsers.parse_hmmsearch_sto(
            output_string,
            input_sequence,
        )
        return template_hits
openfold/data/tools/jackhmmer.py
View file @
bb3f51e5
...
...
@@ -23,6 +23,7 @@ import subprocess
from
typing
import
Any
,
Callable
,
Mapping
,
Optional
,
Sequence
from
urllib
import
request
from
openfold.data
import
parsers
from
openfold.data.tools
import
utils
...
...
@@ -93,10 +94,13 @@ class Jackhmmer:
self
.
streaming_callback
=
streaming_callback
def
_query_chunk
(
self
,
input_fasta_path
:
str
,
database_path
:
str
self
,
input_fasta_path
:
str
,
database_path
:
str
,
max_sequences
:
Optional
[
int
]
=
None
)
->
Mapping
[
str
,
Any
]:
"""Queries the database chunk using Jackhmmer."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
sto_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.sto"
)
# The F1/F2/F3 are the expected proportion to pass each of the filtering
...
...
@@ -167,8 +171,11 @@ class Jackhmmer:
with
open
(
tblout_path
)
as
f
:
tbl
=
f
.
read
()
if
(
max_sequences
is
None
):
with
open
(
sto_path
)
as
f
:
sto
=
f
.
read
()
else
:
sto
=
parsers
.
truncate_stockholm_msa
(
sto_path
,
max_sequences
)
raw_output
=
dict
(
sto
=
sto
,
...
...
@@ -180,10 +187,25 @@ class Jackhmmer:
return
raw_output
def
query
(
self
,
input_fasta_path
:
str
)
->
Sequence
[
Mapping
[
str
,
Any
]]:
def
query
(
self
,
input_fasta_path
:
str
,
max_sequences
:
Optional
[
int
]
=
None
)
->
Sequence
[
Sequence
[
Mapping
[
str
,
Any
]]]:
return
self
.
query_multiple
([
input_fasta_path
],
max_sequences
)
def
query_multiple
(
self
,
input_fasta_paths
:
Sequence
[
str
],
max_sequences
:
Optional
[
int
]
=
None
)
->
Sequence
[
Sequence
[
Mapping
[
str
,
Any
]]]:
"""Queries the database using Jackhmmer."""
if
self
.
num_streamed_chunks
is
None
:
return
[
self
.
_query_chunk
(
input_fasta_path
,
self
.
database_path
)]
single_chunk_results
=
[]
for
input_fasta_path
in
input_fasta_paths
:
single_chunk_result
=
self
.
_query_chunk
(
input_fasta_path
,
self
.
database_path
,
max_sequences
,
)
single_chunk_results
.
append
(
single_chunk_result
)
return
single_chunk_results
db_basename
=
os
.
path
.
basename
(
self
.
database_path
)
db_remote_chunk
=
lambda
db_idx
:
f
"
{
self
.
database_path
}
.
{
db_idx
}
"
...
...
@@ -198,7 +220,7 @@ class Jackhmmer:
# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
with
futures
.
ThreadPoolExecutor
(
max_workers
=
2
)
as
executor
:
chunked_output
=
[]
chunked_output
s
=
[
[]
for
_
in
range
(
len
(
input_fasta_paths
))
]
for
i
in
range
(
1
,
self
.
num_streamed_chunks
+
1
):
# Copy the chunk locally
if
i
==
1
:
...
...
@@ -216,13 +238,21 @@ class Jackhmmer:
# Run Jackhmmer with the chunk
future
.
result
()
chunked_output
.
append
(
self
.
_query_chunk
(
input_fasta_path
,
db_local_chunk
(
i
))
for
fasta_idx
,
input_fasta_path
in
enumerate
(
input_fasta_paths
):
chunked_outputs
[
fasta_idx
].
append
(
self
.
_query_chunk
(
input_fasta_path
,
db_local_chunk
(
i
),
max_sequences
)
)
# Remove the local copy of the chunk
os
.
remove
(
db_local_chunk
(
i
))
# Do not set next_future for the last chunk so that this works
# even for databases with only 1 chunk
if
(
i
<
self
.
num_streamed_chunks
):
future
=
next_future
if
self
.
streaming_callback
:
self
.
streaming_callback
(
i
)
return
chunked_output
return
chunked_output
s
openfold/data/tools/kalign.py
View file @
bb3f51e5
...
...
@@ -72,7 +72,7 @@ class Kalign:
"residues long. Got %s (%d residues)."
%
(
s
,
len
(
s
))
)
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
input_fasta_path
=
os
.
path
.
join
(
query_tmp_dir
,
"input.fasta"
)
output_a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.a3m"
)
...
...
openfold/data/tools/parse_msa_files.py
0 → 100644
View file @
bb3f51e5
import
os
,
argparse
,
pickle
,
tempfile
,
concurrent
from
openfold.data
import
parsers
from
concurrent.futures
import
ProcessPoolExecutor
def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
    """Parse one Stockholm file and key the result by its base name.

    Args:
        alignment_dir: Directory containing the file.
        stockholm_file: File name inside `alignment_dir`.

    Returns:
        A single-entry dict mapping the file name without its extension to
        whatever `parsers.parse_stockholm` returns for the file contents.
    """
    path = os.path.join(alignment_dir, stockholm_file)
    file_name, _ = os.path.splitext(stockholm_file)
    # The context manager closes the file; no explicit close() is needed.
    with open(path, "r") as infile:
        msa = parsers.parse_stockholm(infile.read())
    return {file_name: msa}
def parse_a3m_file(alignment_dir: str, a3m_file: str):
    """Parse one A3M file and key the result by its base name.

    Args:
        alignment_dir: Directory containing the file.
        a3m_file: File name inside `alignment_dir`.

    Returns:
        A single-entry dict mapping the file name without its extension to
        whatever `parsers.parse_a3m` returns for the file contents.
    """
    path = os.path.join(alignment_dir, a3m_file)
    file_name, _ = os.path.splitext(a3m_file)
    # The context manager closes the file; no explicit close() is needed.
    with open(path, "r") as infile:
        msa = parsers.parse_a3m(infile.read())
    return {file_name: msa}
def run_parse_all_msa_files_multiprocessing(stockholm_files: list,
                                            a3m_files: list,
                                            alignment_dir: str):
    """Parse every listed MSA file in a process pool and merge the results.

    Args:
        stockholm_files: Stockholm file names inside `alignment_dir`.
        a3m_files: A3M file names inside `alignment_dir`.
        alignment_dir: Directory that contains all of the files.

    Returns:
        A dict combining the single-entry dicts produced by
        `parse_stockholm_file` and `parse_a3m_file`. Files whose parsing
        raised are reported on stdout and left out of the result.
    """
    msa_results = {}
    a3m_tasks = [(alignment_dir, f) for f in a3m_files]
    sto_tasks = [(alignment_dir, f) for f in stockholm_files]

    # Number of workers based on the tasks
    n_workers = len(a3m_tasks) + len(sto_tasks)
    if n_workers == 0:
        # ProcessPoolExecutor rejects max_workers=0 with a ValueError, so an
        # empty directory has to be handled before the pool is created.
        return msa_results

    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        pending = [executor.submit(parse_a3m_file, *task) for task in a3m_tasks]
        pending.extend(
            executor.submit(parse_stockholm_file, *task) for task in sto_tasks
        )

        for future in concurrent.futures.as_completed(pending):
            # Best effort by design: one unreadable file must not discard the
            # alignments that did parse.
            try:
                result = future.result()
                msa_results.update(result)
            except Exception as exc:
                print(f'Task generated an exception: {exc}')
    return msa_results
def main():
    """Parse all MSA files of one alignment directory and pickle the result.

    Picks up every '.a3m' file and every '.sto' file whose name contains
    neither "hmm_output" nor "uniprot_hits", parses them in parallel, writes
    the merged dict to a persistent temporary '.pkl' file and prints that
    file's path on stdout.
    """
    parser = argparse.ArgumentParser(description='Process msa files in parallel')
    # Required: without it os.listdir(None) would silently scan the current
    # directory and every parse task would then fail on os.path.join(None, ...).
    parser.add_argument(
        '--alignment_dir', type=str, required=True, help='path to alignment dir'
    )
    args = parser.parse_args()
    alignment_dir = args.alignment_dir

    # List the directory once so both filters see the same snapshot.
    dir_entries = os.listdir(alignment_dir)
    stockholm_files = [
        name for name in dir_entries
        if name.endswith('.sto')
        and "hmm_output" not in name
        and "uniprot_hits" not in name
    ]
    a3m_files = [name for name in dir_entries if name.endswith('.a3m')]
    msa_data = run_parse_all_msa_files_multiprocessing(
        stockholm_files, a3m_files, alignment_dir
    )
    # delete=False: the file has to outlive this process so that the caller
    # can open the printed path.
    with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile:
        pickle.dump(msa_data, outfile)
        print(outfile.name)


if __name__ == "__main__":
    main()
\ No newline at end of file
openfold/model/embedders.py
View file @
bb3f51e5
...
...
@@ -13,12 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
functools
import
partial
import
torch
import
torch.nn
as
nn
from
typing
import
Tuple
,
Optional
from
openfold.utils
import
all_atom_multimer
from
openfold.utils.feats
import
(
pseudo_beta_fn
,
dgram_from_positions
,
build_template_angle_feat
,
build_template_pair_feat
,
)
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.utils.tensor_utils
import
add
,
one_hot
from
openfold.model.template
import
(
TemplatePairStack
,
TemplatePointwiseAttention
,
)
from
openfold.utils
import
geometry
from
openfold.utils.tensor_utils
import
add
,
one_hot
,
tensor_tree_map
,
dict_multimap
class
InputEmbedder
(
nn
.
Module
):
...
...
@@ -99,12 +113,13 @@ class InputEmbedder(nn.Module):
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
batch: Dict containing
"target_feat":
Features of shape [*, N_res, tf_dim]
"residue_index":
Features of shape [*, N_res]
"msa_feat":
Features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
...
...
@@ -139,6 +154,161 @@ class InputEmbedder(nn.Module):
return
msa_emb
,
pair_emb
class InputEmbedderMultimer(nn.Module):
    """
    Embeds a subset of the input features.

    Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
    """

    def __init__(
        self,
        tf_dim: int,
        msa_dim: int,
        c_z: int,
        c_m: int,
        max_relative_idx: int,
        use_chain_relative: bool,
        max_relative_chain: int,
        **kwargs,
    ):
        """
        Args:
            tf_dim:
                Final dimension of the target features
            msa_dim:
                Final dimension of the MSA features
            c_z:
                Pair embedding dimension
            c_m:
                MSA embedding dimension
            max_relative_idx:
                Residue-index offsets are shifted by this value and clamped
                to [0, 2 * max_relative_idx] before one-hot encoding
            use_chain_relative:
                If True, relpos additionally encodes whether two positions
                share "entity_id" and their clamped "sym_id" difference
            max_relative_chain:
                Clamp value for the "sym_id" difference; only read when
                use_chain_relative is True
            kwargs:
                Accepted and ignored
        """
        super(InputEmbedderMultimer, self).__init__()

        self.tf_dim = tf_dim
        self.msa_dim = msa_dim

        self.c_z = c_z
        self.c_m = c_m

        self.linear_tf_z_i = Linear(tf_dim, c_z)
        self.linear_tf_z_j = Linear(tf_dim, c_z)
        self.linear_tf_m = Linear(tf_dim, c_m)
        self.linear_msa_m = Linear(msa_dim, c_m)

        # RPE stuff
        self.max_relative_idx = max_relative_idx
        self.use_chain_relative = use_chain_relative
        self.max_relative_chain = max_relative_chain
        if(self.use_chain_relative):
            # Widths of the three features concatenated in relpos():
            # (2 * max_relative_idx + 2) residue-offset bins, 1 same-entity
            # flag and (2 * max_relative_chain + 2) sym_id-offset bins.
            self.no_bins = (
                2 * max_relative_idx + 2 +
                1 +
                2 * max_relative_chain + 2
            )
        else:
            self.no_bins = 2 * max_relative_idx + 1
        self.linear_relpos = Linear(self.no_bins, c_z)

    def relpos(self, batch):
        """
        Builds the relative-position features and projects them to c_z.

        Reads "residue_index" and "asym_id" from batch and, when
        use_chain_relative is set, also "entity_id" and "sym_id".

        Returns:
            Output of linear_relpos applied to the concatenated features
        """
        pos = batch["residue_index"]
        asym_id = batch["asym_id"]
        asym_id_same = (asym_id[..., None] == asym_id[..., None, :])
        offset = pos[..., None] - pos[..., None, :]

        # Shift into the non-negative range so the values index one-hot bins.
        clipped_offset = torch.clamp(
            offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
        )

        rel_feats = []
        if(self.use_chain_relative):
            # Pairs with different asym_id all land in the extra last bin
            # (value 2 * max_relative_idx + 1).
            final_offset = torch.where(
                asym_id_same,
                clipped_offset,
                (2 * self.max_relative_idx + 1) *
                torch.ones_like(clipped_offset)
            )
            boundaries = torch.arange(
                start=0, end=2 * self.max_relative_idx + 2, device=final_offset.device
            )
            rel_pos = one_hot(
                final_offset,
                boundaries,
            )

            rel_feats.append(rel_pos)

            entity_id = batch["entity_id"]
            entity_id_same = (entity_id[..., None] == entity_id[..., None, :])
            rel_feats.append(entity_id_same[..., None].to(dtype=rel_pos.dtype))

            sym_id = batch["sym_id"]
            rel_sym_id = sym_id[..., None] - sym_id[..., None, :]

            max_rel_chain = self.max_relative_chain
            clipped_rel_chain = torch.clamp(
                rel_sym_id + max_rel_chain,
                0,
                2 * max_rel_chain,
            )

            # Same scheme as above: pairs with different entity_id get the
            # extra last bin (value 2 * max_rel_chain + 1).
            final_rel_chain = torch.where(
                entity_id_same,
                clipped_rel_chain,
                (2 * max_rel_chain + 1) *
                torch.ones_like(clipped_rel_chain)
            )

            boundaries = torch.arange(
                start=0, end=2 * max_rel_chain + 2, device=final_rel_chain.device
            )
            rel_chain = one_hot(
                final_rel_chain,
                boundaries,
            )

            rel_feats.append(rel_chain)
        else:
            boundaries = torch.arange(
                start=0, end=2 * self.max_relative_idx + 1, device=clipped_offset.device
            )
            rel_pos = one_hot(
                clipped_offset, boundaries,
            )
            rel_feats.append(rel_pos)

        # Match the dtype of the projection weights before the linear layer.
        rel_feat = torch.cat(rel_feats, dim=-1).to(
            self.linear_relpos.weight.dtype
        )

        return self.linear_relpos(rel_feat)

    def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            batch:
                Dict providing "target_feat" and "msa_feat", plus the keys
                read by relpos
        Returns:
            msa_emb:
                [*, N_clust, N_res, c_m] MSA embedding
            pair_emb:
                [*, N_res, N_res, c_z] pair embedding
        """
        tf = batch["target_feat"]
        msa = batch["msa_feat"]

        # [*, N_res, c_z]
        tf_emb_i = self.linear_tf_z_i(tf)
        tf_emb_j = self.linear_tf_z_j(tf)

        # [*, N_res, N_res, c_z]
        pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
        pair_emb = pair_emb + self.relpos(batch)

        # [*, N_clust, N_res, c_m]
        n_clust = msa.shape[-3]
        # Broadcast the target-feature embedding over the N_clust dimension
        # (expand creates a view, not a copy).
        tf_m = (
            self.linear_tf_m(tf)
            .unsqueeze(-3)
            .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
        )
        msa_emb = self.linear_msa_m(msa) + tf_m

        return msa_emb, pair_emb
class
PreembeddingEmbedder
(
nn
.
Module
):
"""
Embeds the sequence pre-embedding passed to the model and the target_feat features.
...
...
@@ -335,7 +505,7 @@ class RecyclingEmbedder(nn.Module):
return
m_update
,
z_update
class
Template
A
ngleEmbedder
(
nn
.
Module
):
class
Template
Si
ngleEmbedder
(
nn
.
Module
):
"""
Embeds the "template_angle_feat" feature.
...
...
@@ -355,7 +525,7 @@ class TemplateAngleEmbedder(nn.Module):
c_out:
Output channel dimension
"""
super
(
Template
A
ngleEmbedder
,
self
).
__init__
()
super
(
Template
Si
ngleEmbedder
,
self
).
__init__
()
self
.
c_out
=
c_out
self
.
c_in
=
c_in
...
...
@@ -459,3 +629,356 @@ class ExtraMSAEmbedder(nn.Module):
x
=
self
.
linear
(
x
)
return
x
class TemplateEmbedder(nn.Module):
    """
    Bundles the template single embedder, pair embedder, pair stack and
    pointwise attention, each constructed from its section of `config`.
    """

    def __init__(self, config):
        """
        Args:
            config:
                Config object with the sub-sections
                "template_single_embedder", "template_pair_embedder",
                "template_pair_stack" and "template_pointwise_attention".
                forward() also reads use_unit_vector, inf, eps, distogram
                and embed_angles from it.
        """
        super(TemplateEmbedder, self).__init__()

        self.config = config
        self.template_single_embedder = TemplateSingleEmbedder(
            **config["template_single_embedder"],
        )
        self.template_pair_embedder = TemplatePairEmbedder(
            **config["template_pair_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **config["template_pair_stack"],
        )
        self.template_pointwise_att = TemplatePointwiseAttention(
            **config["template_pointwise_attention"],
        )

    def forward(
        self,
        batch,
        z,
        pair_mask,
        templ_dim,
        chunk_size,
        _mask_trans=True,
        use_deepspeed_evo_attention=False,
        use_lma=False,
        inplace_safe=False
    ):
        """
        Args:
            batch:
                Dict of template features; "template_aatype" and
                "template_mask" are read here, the rest by the feature
                builders
            z:
                Pair embedding; supplies dtype, leading shape and the
                second argument of the pointwise attention
            pair_mask:
                Mask passed (with a template dimension inserted) to the
                pair stack
            templ_dim:
                Index of the template dimension in the batch tensors
            chunk_size:
                Forwarded to the pair stack
            inplace_safe:
                If True, results are written into a preallocated tensor
                and the final masking is done in place
        Returns:
            Dict with "template_pair_embedding" and, when
            config.embed_angles is set, "template_single_embedding"
        """
        # Embed the templates one at a time (with a poor man's vmap)
        pair_embeds = []
        n = z.shape[-2]
        n_templ = batch["template_aatype"].shape[templ_dim]

        if (inplace_safe):
            # We'll preallocate the full pair tensor now to avoid manifesting
            # a second copy during the stack later on
            t_pair = z.new_zeros(
                z.shape[:-3] +
                (n_templ, n, n, self.config.template_pair_embedder.c_out)
            )

        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            # Slice template i out of every tensor in the batch and drop the
            # template dimension.
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
                batch,
            )

            # [*, N, N, C_t]
            t = build_template_pair_feat(
                single_template_feats,
                use_unit_vector=self.config.use_unit_vector,
                inf=self.config.inf,
                eps=self.config.eps,
                **self.config.distogram,
            ).to(z.dtype)
            t = self.template_pair_embedder(t)

            if (inplace_safe):
                t_pair[..., i, :, :, :] = t
            else:
                pair_embeds.append(t)

            # Drop the local reference before the next iteration allocates.
            del t

        if (not inplace_safe):
            t_pair = torch.stack(pair_embeds, dim=templ_dim)

        del pair_embeds

        # [*, S_t, N, N, C_z]
        t = self.template_pair_stack(
            t_pair,
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=chunk_size,
            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
            use_lma=use_lma,
            inplace_safe=inplace_safe,
            _mask_trans=_mask_trans,
        )
        del t_pair

        # [*, N, N, C_z]
        t = self.template_pointwise_att(
            t,
            z,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
            use_lma=use_lma,
        )

        # True where the template mask has at least one non-zero entry;
        # multiplying by it zeroes the embedding everywhere else.
        t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
        # Append singletons
        t_mask = t_mask.reshape(
            *t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
        )

        if (inplace_safe):
            t *= t_mask
        else:
            t = t * t_mask

        ret = {}

        ret.update({"template_pair_embedding": t})

        del t

        if self.config.embed_angles:
            template_angle_feat = build_template_angle_feat(
                batch
            )

            # [*, S_t, N, C_m]
            a = self.template_single_embedder(template_angle_feat)

            ret["template_single_embedding"] = a

        return ret
class TemplatePairEmbedderMultimer(nn.Module):
    """
    Sums several linear projections of per-template pair features
    (distogram, masks, residue-type one-hots, unit-vector components and
    the layer-normed query embedding) into one activation of width c_out.
    """

    def __init__(self,
        c_in: int,
        c_out: int,
        c_dgram: int,
        c_aatype: int,
    ):
        """
        Args:
            c_in:
                Channel dimension of the query embedding
            c_out:
                Output channel dimension shared by all projections
            c_dgram:
                Channel dimension of the template distogram
            c_aatype:
                Channel dimension of the residue-type one-hot
        """
        super(TemplatePairEmbedderMultimer, self).__init__()

        self.dgram_linear = Linear(c_dgram, c_out, init='relu')
        self.aatype_linear_1 = Linear(c_aatype, c_out, init='relu')
        self.aatype_linear_2 = Linear(c_aatype, c_out, init='relu')
        self.query_embedding_layer_norm = LayerNorm(c_in)
        self.query_embedding_linear = Linear(c_in, c_out, init='relu')

        self.pseudo_beta_mask_linear = Linear(1, c_out, init='relu')
        self.x_linear = Linear(1, c_out, init='relu')
        self.y_linear = Linear(1, c_out, init='relu')
        self.z_linear = Linear(1, c_out, init='relu')
        self.backbone_mask_linear = Linear(1, c_out, init='relu')

    def forward(self,
        template_dgram: torch.Tensor,
        aatype_one_hot: torch.Tensor,
        query_embedding: torch.Tensor,
        pseudo_beta_mask: torch.Tensor,
        backbone_mask: torch.Tensor,
        multichain_mask_2d: torch.Tensor,
        unit_vector: geometry.Vec3Array,
    ) -> torch.Tensor:
        """
        Returns the sum of all projected features.

        NOTE: template_dgram is masked in place, so the caller's tensor is
        modified by this call.
        """
        # Plain float accumulator; it becomes a tensor at the first addition.
        act = 0.

        # Outer product of the 1D mask with itself gives the pairwise mask.
        pseudo_beta_mask_2d = (
            pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
        )
        pseudo_beta_mask_2d *= multichain_mask_2d
        template_dgram *= pseudo_beta_mask_2d[..., None]
        act += self.dgram_linear(template_dgram)
        act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])

        aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
        # The same one-hot is broadcast along two different pair axes, each
        # with its own projection.
        act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
        act += self.aatype_linear_2(aatype_one_hot[..., None, :])

        backbone_mask_2d = (
            backbone_mask[..., None] * backbone_mask[..., None, :]
        )
        backbone_mask_2d *= multichain_mask_2d
        # Iterating unit_vector yields its three components, each masked and
        # cast to the query embedding's dtype.
        x, y, z = [(coord * backbone_mask_2d).to(dtype=query_embedding.dtype) for coord in unit_vector]
        act += self.x_linear(x[..., None])
        act += self.y_linear(y[..., None])
        act += self.z_linear(z[..., None])

        act += self.backbone_mask_linear(backbone_mask_2d[..., None].to(dtype=query_embedding.dtype))

        query_embedding = self.query_embedding_layer_norm(query_embedding)
        act += self.query_embedding_linear(query_embedding)

        return act
class TemplateSingleEmbedderMultimer(nn.Module):
    """
    Embeds per-residue template features (residue-type one-hot plus masked
    sin/cos of the chi angles and the chi mask) with Linear -> ReLU -> Linear.
    """

    def __init__(self,
        c_in: int,
        c_out: int,
    ):
        """
        Args:
            c_in:
                Channel dimension of the concatenated template features
            c_out:
                Output channel dimension
        """
        super(TemplateSingleEmbedderMultimer, self).__init__()
        self.template_single_embedder = Linear(c_in, c_out)
        self.template_projector = Linear(c_out, c_out)

    def forward(self,
        batch,
        atom_pos,
        aatype_one_hot,
    ):
        """
        Args:
            batch:
                Dict providing "template_all_atom_positions" (dtype only),
                "template_all_atom_mask" and "template_aatype"
            atom_pos:
                Atom positions handed to compute_chi_angles
            aatype_one_hot:
                One-hot residue types, concatenated in front of the angle
                features
        Returns:
            Dict with "template_single_embedding" and "template_mask"
        """
        out = {}

        # All outputs are cast to the dtype of the raw template positions.
        dtype = batch["template_all_atom_positions"].dtype

        template_chi_angles, template_chi_mask = (
            all_atom_multimer.compute_chi_angles(
                atom_pos,
                batch["template_all_atom_mask"],
                batch["template_aatype"],
            )
        )

        template_features = torch.cat(
            [
                aatype_one_hot,
                torch.sin(template_chi_angles) * template_chi_mask,
                torch.cos(template_chi_angles) * template_chi_mask,
                template_chi_mask,
            ],
            dim=-1,
        ).to(dtype=dtype)

        # The mask entry at index 0 of the last axis serves as the
        # per-residue template mask.
        template_mask = template_chi_mask[..., 0].to(dtype=dtype)

        template_activations = self.template_single_embedder(
            template_features
        )
        template_activations = torch.nn.functional.relu(
            template_activations
        )
        template_activations = self.template_projector(
            template_activations,
        )

        out["template_single_embedding"] = (
            template_activations
        )
        out["template_mask"] = template_mask

        return out
class TemplateEmbedderMultimer(nn.Module):
    """
    Embeds templates one at a time with the multimer pair and single
    embedders, runs the stacked pair embeddings through the template pair
    stack and reduces them to one pair update.
    """

    def __init__(self, config):
        """
        Args:
            config:
                Config object with the sub-sections
                "template_pair_embedder", "template_single_embedder" and
                "template_pair_stack", plus c_t and c_z for the output
                projection. forward() also reads inf and distogram.
        """
        super(TemplateEmbedderMultimer, self).__init__()

        self.config = config
        self.template_pair_embedder = TemplatePairEmbedderMultimer(
            **config["template_pair_embedder"],
        )
        self.template_single_embedder = TemplateSingleEmbedderMultimer(
            **config["template_single_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **config["template_pair_stack"],
        )

        self.linear_t = Linear(config.c_t, config.c_z)

    def forward(self,
        batch,
        z,
        padding_mask_2d,
        templ_dim,
        chunk_size,
        multichain_mask_2d,
        _mask_trans=True,
        use_deepspeed_evo_attention=False,
        use_lma=False,
        inplace_safe=False
    ):
        """
        Args:
            batch:
                Dict of template features; "template_aatype",
                "template_all_atom_positions" and "template_all_atom_mask"
                are read here
            z:
                Pair embedding, passed to the pair embedder as the query
                embedding and used for the mask dtype
            padding_mask_2d:
                Mask passed (with a template dimension inserted) to the
                pair stack
            templ_dim:
                Index of the template dimension in the batch tensors
            chunk_size:
                Forwarded to the pair stack
            multichain_mask_2d:
                Pairwise mask forwarded to the pair embedder
        Returns:
            Dict with "template_pair_embedding" (after the pair stack,
            averaging over templates, ReLU and linear_t) plus the
            per-template outputs of the single embedder concatenated along
            templ_dim
        """
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            # index_select keeps the template dimension (size 1), which the
            # torch.cat over templ_dim below relies on.
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx),
                batch,
            )

            single_template_embeds = {}

            template_positions, pseudo_beta_mask = pseudo_beta_fn(
                single_template_feats["template_aatype"],
                single_template_feats["template_all_atom_positions"],
                single_template_feats["template_all_atom_mask"])

            template_dgram = dgram_from_positions(
                template_positions,
                inf=self.config.inf,
                **self.config.distogram,
            )

            aatype_one_hot = torch.nn.functional.one_hot(
                single_template_feats["template_aatype"], 22,
            )

            raw_atom_pos = single_template_feats["template_all_atom_positions"]

            # Vec3Arrays are required to be float32
            atom_pos = geometry.Vec3Array.from_array(raw_atom_pos.to(dtype=torch.float32))

            rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
                atom_pos,
                single_template_feats["template_all_atom_mask"],
                single_template_feats["template_aatype"],
            )
            # Express every frame origin in every other frame, then
            # normalize to get the pairwise unit vectors.
            points = rigid.translation
            rigid_vec = rigid[..., None].inverse().apply_to_point(points)
            unit_vector = rigid_vec.normalized()

            pair_act = self.template_pair_embedder(
                template_dgram,
                aatype_one_hot,
                z,
                pseudo_beta_mask,
                backbone_mask,
                multichain_mask_2d,
                unit_vector,
            )

            single_template_embeds["template_pair_embedding"] = pair_act
            single_template_embeds.update(
                self.template_single_embedder(
                    single_template_feats,
                    atom_pos,
                    aatype_one_hot,
                )
            )
            template_embeds.append(single_template_embeds)

        # Merge the per-template dicts key by key along the template dim.
        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
        )

        # [*, S_t, N, N, C_z]
        t = self.template_pair_stack(
            template_embeds["template_pair_embedding"],
            padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=chunk_size,
            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
            use_lma=use_lma,
            inplace_safe=inplace_safe,
            _mask_trans=_mask_trans,
        )
        # [*, N, N, C_z]
        t = torch.sum(t, dim=-4) / n_templ
        t = torch.nn.functional.relu(t)
        t = self.linear_t(t)
        template_embeds["template_pair_embedding"] = t

        return template_embeds
openfold/model/evoformer.py
View file @
bb3f51e5
...
...
@@ -18,6 +18,7 @@ import torch
import
torch.nn
as
nn
from
typing
import
Tuple
,
Sequence
,
Optional
from
functools
import
partial
from
abc
import
ABC
,
abstractmethod
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.model.dropout
import
DropoutRowwise
,
DropoutColumnwise
...
...
@@ -36,6 +37,8 @@ from openfold.model.triangular_attention import (
from
openfold.model.triangular_multiplicative_update
import
(
TriangleMultiplicationOutgoing
,
TriangleMultiplicationIncoming
,
FusedTriangleMultiplicationIncoming
,
FusedTriangleMultiplicationOutgoing
)
from
openfold.utils.checkpointing
import
checkpoint_blocks
,
get_checkpoint_fn
from
openfold.utils.chunk_utils
import
chunk_layer
,
ChunkSizeTuner
...
...
@@ -117,35 +120,31 @@ class MSATransition(nn.Module):
return
m
class
EvoformerBlockCore
(
nn
.
Module
):
class
PairStack
(
nn
.
Module
):
def
__init__
(
self
,
c_m
:
int
,
c_z
:
int
,
c_hidden_opm
:
int
,
c_hidden_mul
:
int
,
c_hidden_pair_att
:
int
,
no_heads_msa
:
int
,
no_heads_pair
:
int
,
transition_n
:
int
,
pair_dropout
:
float
,
fuse_projection_weights
:
bool
,
inf
:
float
,
eps
:
float
,
_is_extra_msa_stack
:
bool
=
False
,
eps
:
float
):
super
(
EvoformerBlockCore
,
self
).
__init__
()
super
(
PairStack
,
self
).
__init__
()
self
.
msa_transition
=
MSATransition
(
c_m
=
c_m
,
n
=
transition_n
,
if
fuse_projection_weights
:
self
.
tri_mul_out
=
FusedTriangleMultiplicationOutgoing
(
c_z
,
c_hidden_mul
,
)
self
.
outer_product_mean
=
OuterProductMean
(
c_m
,
self
.
tri_mul_in
=
FusedTriangleMultiplicationIncoming
(
c_z
,
c_hidden_
opm
,
c_hidden_
mul
,
)
else
:
self
.
tri_mul_out
=
TriangleMultiplicationOutgoing
(
c_z
,
c_hidden_mul
,
...
...
@@ -176,64 +175,30 @@ class EvoformerBlockCore(nn.Module):
self
.
ps_dropout_row_layer
=
DropoutRowwise
(
pair_dropout
)
def
forward
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_deepspeed_evo_attention
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
_attn_chunk_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
msa_trans_mask
=
msa_mask
if
_mask_trans
else
None
pair_trans_mask
=
pair_mask
if
_mask_trans
else
None
if
(
_attn_chunk_size
is
None
):
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
m
,
z
=
input_tensors
m
=
add
(
m
,
self
.
msa_transition
(
m
,
mask
=
msa_trans_mask
,
chunk_size
=
chunk_size
,
),
inplace
=
inplace_safe
,
)
if
(
_offload_inference
and
inplace_safe
):
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
1
])
==
2
)
input_tensors
[
1
]
=
input_tensors
[
1
].
cpu
()
torch
.
cuda
.
empty_cache
()
m
,
z
=
input_tensors
opm
=
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace_safe
=
inplace_safe
)
if
(
_offload_inference
and
inplace_safe
):
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
input_tensors
[
0
]
=
input_tensors
[
0
].
cpu
()
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
opm
.
device
)
m
,
z
=
input_tensors
z
=
add
(
z
,
opm
,
inplace
=
inplace_safe
)
del
opm
tmu_update
=
self
.
tri_mul_out
(
z
,
mask
=
pair_mask
,
inplace_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
)
if
(
not
inplace_safe
):
if
(
not
inplace_safe
):
z
=
z
+
self
.
ps_dropout_row_layer
(
tmu_update
)
else
:
z
=
tmu_update
...
...
@@ -246,7 +211,7 @@ class EvoformerBlockCore(nn.Module):
inplace_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
)
if
(
not
inplace_safe
):
if
(
not
inplace_safe
):
z
=
z
+
self
.
ps_dropout_row_layer
(
tmu_update
)
else
:
z
=
tmu_update
...
...
@@ -269,9 +234,8 @@ class EvoformerBlockCore(nn.Module):
)
z
=
z
.
transpose
(
-
2
,
-
3
)
if
(
inplace_safe
):
input_tensors
[
1
]
=
z
.
contiguous
()
z
=
input_tensors
[
1
]
if
(
inplace_safe
):
z
=
z
.
contiguous
()
z
=
add
(
z
,
self
.
ps_dropout_row_layer
(
...
...
@@ -289,9 +253,8 @@ class EvoformerBlockCore(nn.Module):
)
z
=
z
.
transpose
(
-
2
,
-
3
)
if
(
inplace_safe
):
input_tensors
[
1
]
=
z
.
contiguous
()
z
=
input_tensors
[
1
]
if
(
inplace_safe
):
z
=
z
.
contiguous
()
z
=
add
(
z
,
self
.
pair_transition
(
...
...
@@ -300,19 +263,11 @@ class EvoformerBlockCore(nn.Module):
inplace
=
inplace_safe
,
)
if
(
_offload_inference
and
inplace_safe
):
device
=
z
.
device
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
assert
(
sys
.
getrefcount
(
input_tensors
[
1
])
==
2
)
input_tensors
[
0
]
=
input_tensors
[
0
].
to
(
device
)
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
device
)
m
,
z
=
input_tensors
return
m
,
z
return
z
class
EvoformerBlock
(
nn
.
Module
):
class
MSABlock
(
nn
.
Module
,
ABC
):
@
abstractmethod
def
__init__
(
self
,
c_m
:
int
,
c_z
:
int
,
...
...
@@ -325,11 +280,14 @@ class EvoformerBlock(nn.Module):
transition_n
:
int
,
msa_dropout
:
float
,
pair_dropout
:
float
,
no_column_attention
:
bool
,
opm_first
:
bool
,
fuse_projection_weights
:
bool
,
inf
:
float
,
eps
:
float
,
):
super
(
EvoformerBlock
,
self
).
__init__
()
super
(
MSABlock
,
self
).
__init__
()
self
.
opm_first
=
opm_first
self
.
msa_att_row
=
MSARowAttentionWithPairBias
(
c_m
=
c_m
,
...
...
@@ -339,30 +297,127 @@ class EvoformerBlock(nn.Module):
inf
=
inf
,
)
# Specifically, seqemb mode does not use column attention
self
.
no_column_attention
=
no_column_attention
if
not
self
.
no_column_attention
:
self
.
msa_att_col
=
MSAColumnAttention
(
self
.
msa_dropout_layer
=
DropoutRowwise
(
msa_dropout
)
self
.
msa_transition
=
MSATransition
(
c_m
=
c_m
,
n
=
transition_n
,
)
self
.
outer_product_mean
=
OuterProductMean
(
c_m
,
c_hidden_msa_att
,
no_heads_msa
,
c_z
,
c_hidden_opm
,
)
self
.
pair_stack
=
PairStack
(
c_z
=
c_z
,
c_hidden_mul
=
c_hidden_mul
,
c_hidden_pair_att
=
c_hidden_pair_att
,
no_heads_pair
=
no_heads_pair
,
transition_n
=
transition_n
,
pair_dropout
=
pair_dropout
,
fuse_projection_weights
=
fuse_projection_weights
,
inf
=
inf
,
eps
=
eps
)
self
.
msa_dropout_layer
=
DropoutRowwise
(
msa_dropout
)
def
_compute_opm
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
inplace_safe
:
bool
=
False
,
_offload_inference
:
bool
=
False
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
.
core
=
EvoformerBlockCore
(
c_m
=
c_m
,
m
,
z
=
input_tensors
if
(
_offload_inference
and
inplace_safe
):
# m: GPU, z: CPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
1
])
==
2
)
input_tensors
[
1
]
=
input_tensors
[
1
].
cpu
()
m
,
z
=
input_tensors
opm
=
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace_safe
=
inplace_safe
)
if
(
_offload_inference
and
inplace_safe
):
# m: GPU, z: GPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
opm
.
device
)
m
,
z
=
input_tensors
z
=
add
(
z
,
opm
,
inplace
=
inplace_safe
)
del
opm
return
m
,
z
@
abstractmethod
def
forward
(
self
,
m
:
Optional
[
torch
.
Tensor
],
z
:
Optional
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_deepspeed_evo_attention
:
bool
=
False
,
use_lma
:
bool
=
False
,
use_flash
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
_offloadable_inputs
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
pass
class
EvoformerBlock
(
MSABlock
):
def
__init__
(
self
,
c_m
:
int
,
c_z
:
int
,
c_hidden_msa_att
:
int
,
c_hidden_opm
:
int
,
c_hidden_mul
:
int
,
c_hidden_pair_att
:
int
,
no_heads_msa
:
int
,
no_heads_pair
:
int
,
transition_n
:
int
,
msa_dropout
:
float
,
pair_dropout
:
float
,
no_column_attention
:
bool
,
opm_first
:
bool
,
fuse_projection_weights
:
bool
,
inf
:
float
,
eps
:
float
,
):
super
(
EvoformerBlock
,
self
).
__init__
(
c_m
=
c_m
,
c_z
=
c_z
,
c_hidden_msa_att
=
c_hidden_msa_att
,
c_hidden_opm
=
c_hidden_opm
,
c_hidden_mul
=
c_hidden_mul
,
c_hidden_pair_att
=
c_hidden_pair_att
,
no_heads_msa
=
no_heads_msa
,
no_heads_pair
=
no_heads_pair
,
transition_n
=
transition_n
,
msa_dropout
=
msa_dropout
,
pair_dropout
=
pair_dropout
,
opm_first
=
opm_first
,
fuse_projection_weights
=
fuse_projection_weights
,
inf
=
inf
,
eps
=
eps
)
# Specifically, seqemb mode does not use column attention
self
.
no_column_attention
=
no_column_attention
if
not
self
.
no_column_attention
:
self
.
msa_att_col
=
MSAColumnAttention
(
c_m
,
c_hidden_msa_att
,
no_heads_msa
,
inf
=
inf
,
eps
=
eps
,
)
def
forward
(
self
,
...
...
@@ -380,6 +435,9 @@ class EvoformerBlock(nn.Module):
_offload_inference
:
bool
=
False
,
_offloadable_inputs
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
msa_trans_mask
=
msa_mask
if
_mask_trans
else
None
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
...
...
@@ -391,6 +449,15 @@ class EvoformerBlock(nn.Module):
m
,
z
=
input_tensors
if
self
.
opm_first
:
del
m
,
z
m
,
z
=
self
.
_compute_opm
(
input_tensors
=
input_tensors
,
msa_mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace_safe
=
inplace_safe
,
_offload_inference
=
_offload_inference
)
m
=
add
(
m
,
self
.
msa_dropout_layer
(
self
.
msa_att_row
(
...
...
@@ -406,6 +473,14 @@ class EvoformerBlock(nn.Module):
inplace
=
inplace_safe
,
)
if
(
_offload_inference
and
inplace_safe
):
# m: GPU, z: CPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
1
])
==
2
)
input_tensors
[
1
]
=
input_tensors
[
1
].
cpu
()
torch
.
cuda
.
empty_cache
()
m
,
z
=
input_tensors
# Specifically, column attention is not used in seqemb mode.
if
not
self
.
no_column_attention
:
m
=
add
(
m
,
...
...
@@ -420,28 +495,64 @@ class EvoformerBlock(nn.Module):
inplace
=
inplace_safe
,
)
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
input_tensors
[
1
]]
m
=
add
(
m
,
self
.
msa_transition
(
m
,
mask
=
msa_trans_mask
,
chunk_size
=
chunk_size
,
),
inplace
=
inplace_safe
,
)
if
not
self
.
opm_first
:
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
z
]
del
m
,
z
m
,
z
=
self
.
core
(
input_tensors
,
m
,
z
=
self
.
_compute_opm
(
input_tensors
=
input_tensors
,
msa_mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace_safe
=
inplace_safe
,
_offload_inference
=
_offload_inference
)
if
(
_offload_inference
and
inplace_safe
):
# m: CPU, z: GPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
device
=
input_tensors
[
0
].
device
input_tensors
[
0
]
=
input_tensors
[
0
].
cpu
()
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
device
)
m
,
z
=
input_tensors
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
z
]
del
m
,
z
z
=
self
.
pair_stack
(
z
=
input_tensors
[
1
],
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_deepspeed_evo_attention
=
use_deepspeed_evo_attention
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_attn_chunk_size
=
_attn_chunk_size
,
_offload_inference
=
_offload_inference
,
_attn_chunk_size
=
_attn_chunk_size
)
if
(
_offload_inference
and
inplace_safe
):
# m: GPU, z: GPU
device
=
z
.
device
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
input_tensors
[
0
]
=
input_tensors
[
0
].
to
(
device
)
m
,
_
=
input_tensors
else
:
m
=
input_tensors
[
0
]
return
m
,
z
class
ExtraMSABlock
(
nn
.
Module
):
class
ExtraMSABlock
(
MSABlock
):
"""
Almost identical to the standard EvoformerBlock, except in that the
ExtraMSABlock uses GlobalAttention for MSA column attention and
...
...
@@ -460,42 +571,34 @@ class ExtraMSABlock(nn.Module):
transition_n
:
int
,
msa_dropout
:
float
,
pair_dropout
:
float
,
opm_first
:
bool
,
fuse_projection_weights
:
bool
,
inf
:
float
,
eps
:
float
,
ckpt
:
bool
,
):
super
(
ExtraMSABlock
,
self
).
__init__
()
self
.
ckpt
=
ckpt
self
.
msa_att_row
=
MSARowAttentionWithPairBias
(
c_m
=
c_m
,
c_z
=
c_z
,
c_hidden
=
c_hidden_msa_att
,
no_heads
=
no_heads_msa
,
inf
=
inf
,
)
self
.
msa_att_col
=
MSAColumnGlobalAttention
(
c_in
=
c_m
,
c_hidden
=
c_hidden_msa_att
,
no_heads
=
no_heads_msa
,
inf
=
inf
,
eps
=
eps
,
)
self
.
msa_dropout_layer
=
DropoutRowwise
(
msa_dropout
)
self
.
core
=
EvoformerBlockCore
(
c_m
=
c_m
,
super
(
ExtraMSABlock
,
self
).
__init__
(
c_m
=
c_m
,
c_z
=
c_z
,
c_hidden_msa_att
=
c_hidden_msa_att
,
c_hidden_opm
=
c_hidden_opm
,
c_hidden_mul
=
c_hidden_mul
,
c_hidden_pair_att
=
c_hidden_pair_att
,
no_heads_msa
=
no_heads_msa
,
no_heads_pair
=
no_heads_pair
,
transition_n
=
transition_n
,
msa_dropout
=
msa_dropout
,
pair_dropout
=
pair_dropout
,
opm_first
=
opm_first
,
fuse_projection_weights
=
fuse_projection_weights
,
inf
=
inf
,
eps
=
eps
)
self
.
ckpt
=
ckpt
self
.
msa_att_col
=
MSAColumnGlobalAttention
(
c_in
=
c_m
,
c_hidden
=
c_hidden_msa_att
,
no_heads
=
no_heads_msa
,
inf
=
inf
,
eps
=
eps
,
)
...
...
@@ -525,6 +628,15 @@ class ExtraMSABlock(nn.Module):
m
,
z
=
input_tensors
if
self
.
opm_first
:
del
m
,
z
m
,
z
=
self
.
_compute_opm
(
input_tensors
=
input_tensors
,
msa_mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace_safe
=
inplace_safe
,
_offload_inference
=
_offload_inference
)
m
=
add
(
m
,
self
.
msa_dropout_layer
(
self
.
msa_att_row
(
...
...
@@ -542,15 +654,25 @@ class ExtraMSABlock(nn.Module):
inplace
=
inplace_safe
,
)
if
(
not
inplace_safe
):
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
z
]
del
m
,
z
def
fn
(
input_tensors
):
m
=
add
(
input_tensors
[
0
],
m
,
z
=
input_tensors
if
(
_offload_inference
and
inplace_safe
):
# m: GPU, z: CPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
1
])
==
2
)
input_tensors
[
1
]
=
input_tensors
[
1
].
cpu
()
torch
.
cuda
.
empty_cache
()
m
,
z
=
input_tensors
m
=
add
(
m
,
self
.
msa_att_col
(
input_tensors
[
0
],
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
...
...
@@ -558,27 +680,63 @@ class ExtraMSABlock(nn.Module):
inplace
=
inplace_safe
,
)
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
input_tensors
[
1
]]
m
=
add
(
m
,
self
.
msa_transition
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
),
inplace
=
inplace_safe
,
)
if
not
self
.
opm_first
:
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
z
]
del
m
del
m
,
z
m
,
z
=
self
.
core
(
input_tensors
,
m
,
z
=
self
.
_compute_opm
(
input_tensors
=
input_tensors
,
msa_mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace_safe
=
inplace_safe
,
_offload_inference
=
_offload_inference
)
if
(
_offload_inference
and
inplace_safe
):
# m: CPU, z: GPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
device
=
input_tensors
[
0
].
device
input_tensors
[
0
]
=
input_tensors
[
0
].
cpu
()
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
device
)
m
,
z
=
input_tensors
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
z
]
del
m
,
z
z
=
self
.
pair_stack
(
input_tensors
[
1
],
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_deepspeed_evo_attention
=
use_deepspeed_evo_attention
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_attn_chunk_size
=
_attn_chunk_size
,
_offload_inference
=
_offload_inference
,
_attn_chunk_size
=
_attn_chunk_size
)
m
=
input_tensors
[
0
]
if
(
_offload_inference
and
inplace_safe
):
# m: GPU, z: GPU
device
=
z
.
device
del
m
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
input_tensors
[
0
]
=
input_tensors
[
0
].
to
(
device
)
m
,
_
=
input_tensors
return
m
,
z
if
(
torch
.
is_grad_enabled
()
and
self
.
ckpt
):
if
(
torch
.
is_grad_enabled
()
and
self
.
ckpt
):
checkpoint_fn
=
get_checkpoint_fn
()
m
,
z
=
checkpoint_fn
(
fn
,
input_tensors
)
else
:
...
...
@@ -609,8 +767,10 @@ class EvoformerStack(nn.Module):
transition_n
:
int
,
msa_dropout
:
float
,
pair_dropout
:
float
,
blocks_per_ckpt
:
int
,
no_column_attention
:
bool
,
opm_first
:
bool
,
fuse_projection_weights
:
bool
,
blocks_per_ckpt
:
int
,
inf
:
float
,
eps
:
float
,
clear_cache_between_blocks
:
bool
=
False
,
...
...
@@ -646,11 +806,18 @@ class EvoformerStack(nn.Module):
Dropout rate for MSA activations
pair_dropout:
Dropout used for pair activations
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
no_column_attention:
When True, doesn't use column attention. Required for running
sequence embedding mode
opm_first:
When True, Outer Product Mean is performed at the beginning of
the Evoformer block instead of after the MSA Stack.
Used in Multimer pipeline.
fuse_projection_weights:
When True, uses FusedTriangleMultiplicativeUpdate variant in
the Pair Stack. Used in Multimer pipeline.
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
...
...
@@ -678,6 +845,8 @@ class EvoformerStack(nn.Module):
msa_dropout
=
msa_dropout
,
pair_dropout
=
pair_dropout
,
no_column_attention
=
no_column_attention
,
opm_first
=
opm_first
,
fuse_projection_weights
=
fuse_projection_weights
,
inf
=
inf
,
eps
=
eps
,
)
...
...
@@ -873,6 +1042,8 @@ class ExtraMSAStack(nn.Module):
transition_n
:
int
,
msa_dropout
:
float
,
pair_dropout
:
float
,
opm_first
:
bool
,
fuse_projection_weights
:
bool
,
inf
:
float
,
eps
:
float
,
ckpt
:
bool
,
...
...
@@ -898,6 +1069,8 @@ class ExtraMSAStack(nn.Module):
transition_n
=
transition_n
,
msa_dropout
=
msa_dropout
,
pair_dropout
=
pair_dropout
,
opm_first
=
opm_first
,
fuse_projection_weights
=
fuse_projection_weights
,
inf
=
inf
,
eps
=
eps
,
ckpt
=
False
,
...
...
openfold/model/heads.py
View file @
bb3f51e5
...
...
@@ -76,9 +76,17 @@ class AuxiliaryHeads(nn.Module):
if
self
.
config
.
tm
.
enabled
:
tm_logits
=
self
.
tm
(
outputs
[
"pair"
])
aux_out
[
"tm_logits"
]
=
tm_logits
aux_out
[
"p
redicted_
tm_score"
]
=
compute_tm
(
aux_out
[
"ptm_score"
]
=
compute_tm
(
tm_logits
,
**
self
.
config
.
tm
)
asym_id
=
outputs
.
get
(
"asym_id"
)
if
asym_id
is
not
None
:
aux_out
[
"iptm_score"
]
=
compute_tm
(
tm_logits
,
asym_id
=
asym_id
,
interface
=
True
,
**
self
.
config
.
tm
)
aux_out
[
"weighted_ptm_score"
]
=
(
self
.
config
.
tm
[
"iptm_weight"
]
*
aux_out
[
"iptm_score"
]
+
self
.
config
.
tm
[
"ptm_weight"
]
*
aux_out
[
"ptm_score"
])
aux_out
.
update
(
compute_predicted_aligned_error
(
tm_logits
,
...
...
openfold/model/model.py
View file @
bb3f51e5
...
...
@@ -18,11 +18,20 @@ import weakref
import
torch
import
torch.nn
as
nn
from
openfold.data
import
data_transforms_multimer
from
openfold.utils.feats
import
(
pseudo_beta_fn
,
build_extra_msa_feat
,
dgram_from_positions
,
atom14_to_atom37
,
)
from
openfold.utils.tensor_utils
import
masked_mean
from
openfold.model.embedders
import
(
InputEmbedder
,
InputEmbedderMultimer
,
RecyclingEmbedder
,
Template
Angle
Embedder
,
Template
Pair
Embedder
,
TemplateEmbedder
,
TemplateEmbedder
Multimer
,
ExtraMSAEmbedder
,
PreembeddingEmbedder
,
)
...
...
@@ -75,9 +84,13 @@ class AlphaFold(nn.Module):
self
.
seqemb_mode
=
config
.
globals
.
seqemb_mode_enabled
# Main trunk + structure module
if
self
.
globals
.
is_multimer
:
self
.
input_embedder
=
InputEmbedderMultimer
(
**
self
.
config
[
"input_embedder"
]
)
elif
self
.
seqemb_mode
:
# If using seqemb mode, embed the sequence embeddings passed
# to the model ("preembeddings") instead of embedding the sequence
if
self
.
seqemb_mode
:
self
.
input_embedder
=
PreembeddingEmbedder
(
**
self
.
config
[
"preembedding_embedder"
],
)
...
...
@@ -85,25 +98,22 @@ class AlphaFold(nn.Module):
self
.
input_embedder
=
InputEmbedder
(
**
self
.
config
[
"input_embedder"
],
)
self
.
recycling_embedder
=
RecyclingEmbedder
(
**
self
.
config
[
"recycling_embedder"
],
)
if
(
self
.
template_config
.
enabled
):
self
.
template_angle_embedder
=
TemplateAngleEmbedder
(
**
self
.
template_config
[
"template_angle_embedder"
],
)
self
.
template_pair_embedder
=
TemplatePairEmbedder
(
**
self
.
template_config
[
"template_pair_embedder"
],
if
self
.
template_config
.
enabled
:
if
self
.
globals
.
is_multimer
:
self
.
template_embedder
=
TemplateEmbedderMultimer
(
self
.
template_config
,
)
self
.
template_pair_stack
=
TemplatePairStack
(
**
self
.
template_config
[
"template_pair_stack"
],
)
self
.
template_pointwise_att
=
TemplatePointwiseAttention
(
**
self
.
template_config
[
"template_pointwise_attention"
],
else
:
self
.
template_embedder
=
TemplateEmbedder
(
self
.
template_config
,
)
if
(
self
.
extra_msa_config
.
enabled
)
:
if
self
.
extra_msa_config
.
enabled
:
self
.
extra_msa_embedder
=
ExtraMSAEmbedder
(
**
self
.
extra_msa_config
[
"extra_msa_embedder"
],
)
...
...
@@ -114,113 +124,87 @@ class AlphaFold(nn.Module):
self
.
evoformer
=
EvoformerStack
(
**
self
.
config
[
"evoformer_stack"
],
)
self
.
structure_module
=
StructureModule
(
is_multimer
=
self
.
globals
.
is_multimer
,
**
self
.
config
[
"structure_module"
],
)
self
.
aux_heads
=
AuxiliaryHeads
(
self
.
config
[
"heads"
],
)
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
):
if
(
self
.
template_config
.
offload_templates
):
def
embed_templates
(
self
,
batch
,
feats
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
):
if
self
.
globals
.
is_multimer
:
asym_id
=
feats
[
"asym_id"
]
multichain_mask_2d
=
(
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]
)
template_embeds
=
self
.
template_embedder
(
batch
,
z
,
pair_mask
.
to
(
dtype
=
z
.
dtype
),
templ_dim
,
chunk_size
=
self
.
globals
.
chunk_size
,
multichain_mask_2d
=
multichain_mask_2d
,
use_deepspeed_evo_attention
=
self
.
globals
.
use_deepspeed_evo_attention
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
)
feats
[
"template_torsion_angles_mask"
]
=
(
template_embeds
[
"template_mask"
]
)
else
:
if
self
.
template_config
.
offload_templates
:
return
embed_templates_offload
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
=
inplace_safe
,
)
elif
(
self
.
template_config
.
average_templates
)
:
elif
self
.
template_config
.
average_templates
:
return
embed_templates_average
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
=
inplace_safe
,
)
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds
=
[]
n
=
z
.
shape
[
-
2
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
if
(
inplace_safe
):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair
=
z
.
new_zeros
(
z
.
shape
[:
-
3
]
+
(
n_templ
,
n
,
n
,
self
.
globals
.
c_t
)
)
for
i
in
range
(
n_templ
):
idx
=
batch
[
"template_aatype"
].
new_tensor
(
i
)
single_template_feats
=
tensor_tree_map
(
lambda
t
:
torch
.
index_select
(
t
,
templ_dim
,
idx
).
squeeze
(
templ_dim
),
template_embeds
=
self
.
template_embedder
(
batch
,
)
# [*, N, N, C_t]
t
=
build_template_pair_feat
(
single_template_feats
,
use_unit_vector
=
self
.
config
.
template
.
use_unit_vector
,
inf
=
self
.
config
.
template
.
inf
,
eps
=
self
.
config
.
template
.
eps
,
**
self
.
config
.
template
.
distogram
,
).
to
(
z
.
dtype
)
t
=
self
.
template_pair_embedder
(
t
)
if
(
inplace_safe
):
t_pair
[...,
i
,
:,
:,
:]
=
t
else
:
pair_embeds
.
append
(
t
)
del
t
if
(
not
inplace_safe
):
t_pair
=
torch
.
stack
(
pair_embeds
,
dim
=
templ_dim
)
del
pair_embeds
# [*, S_t, N, N, C_z]
t
=
self
.
template_pair_stack
(
t_pair
,
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
z
,
pair_mask
.
to
(
dtype
=
z
.
dtype
),
templ_dim
,
chunk_size
=
self
.
globals
.
chunk_size
,
use_deepspeed_evo_attention
=
self
.
globals
.
use_deepspeed_evo_attention
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
del
t_pair
# [*, N, N, C_z]
t
=
self
.
template_pointwise_att
(
t
,
z
,
template_mask
=
batch
[
"template_mask"
].
to
(
dtype
=
z
.
dtype
),
use_lma
=
self
.
globals
.
use_lma
,
)
t_mask
=
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
# Append singletons
t_mask
=
t_mask
.
reshape
(
*
t_mask
.
shape
,
*
([
1
]
*
(
len
(
t
.
shape
)
-
len
(
t_mask
.
shape
)))
_mask_trans
=
self
.
config
.
_mask_trans
)
if
(
inplace_safe
):
t
*=
t_mask
else
:
t
=
t
*
t_mask
ret
=
{}
ret
.
update
({
"template_pair_embedding"
:
t
})
del
t
return
template_embeds
if
self
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
batch
)
def
tolerance_reached
(
self
,
prev_pos
,
next_pos
,
mask
,
eps
=
1e-8
)
->
bool
:
"""
Early stopping criteria based on criteria used in
AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
Args:
prev_pos: Previous atom positions in atom37/14 representation
next_pos: Current atom positions in atom37/14 representation
mask: 1-D sequence mask
eps: Epsilon used in square root calculation
Returns:
Whether to stop recycling early based on the desired tolerance.
"""
# [*, S_t, N, C_m]
a
=
self
.
template_angle_embedder
(
template_angle_feat
)
def
distances
(
points
):
"""Compute all pairwise distances for a set of points."""
d
=
points
[...,
None
,
:]
-
points
[...,
None
,
:,
:]
return
torch
.
sqrt
(
torch
.
sum
(
d
**
2
,
dim
=-
1
))
ret
[
"template_angle_embedding"
]
=
a
if
self
.
config
.
recycle_early_stop_tolerance
<
0
:
return
False
return
ret
ca_idx
=
residue_constants
.
atom_order
[
'CA'
]
sq_diff
=
(
distances
(
prev_pos
[...,
ca_idx
,
:])
-
distances
(
next_pos
[...,
ca_idx
,
:]))
**
2
mask
=
mask
[...,
None
]
*
mask
[...,
None
,
:]
sq_diff
=
masked_mean
(
mask
=
mask
,
value
=
sq_diff
,
dim
=
list
(
range
(
len
(
mask
.
shape
))))
diff
=
torch
.
sqrt
(
sq_diff
+
eps
).
item
()
return
diff
<=
self
.
config
.
recycle_early_stop_tolerance
def
iteration
(
self
,
feats
,
prevs
,
_recycle
=
True
):
# Primary output dictionary
...
...
@@ -229,7 +213,7 @@ class AlphaFold(nn.Module):
# This needs to be done manually for DeepSpeed's sake
dtype
=
next
(
self
.
parameters
()).
dtype
for
k
in
feats
:
if
(
feats
[
k
].
dtype
==
torch
.
float32
)
:
if
feats
[
k
].
dtype
==
torch
.
float32
:
feats
[
k
]
=
feats
[
k
].
to
(
dtype
=
dtype
)
# Grab some data about the input
...
...
@@ -248,18 +232,22 @@ class AlphaFold(nn.Module):
pair_mask
=
seq_mask
[...,
None
]
*
seq_mask
[...,
None
,
:]
msa_mask
=
feats
[
"msa_mask"
]
## Initialize the SingleSeq and pair representations
if
self
.
globals
.
is_multimer
:
# Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m
,
z
=
self
.
input_embedder
(
feats
)
elif
self
.
seqemb_mode
:
# Initialize the SingleSeq and pair representations
# m: [*, 1, N, C_m]
# z: [*, N, N, C_z]
if
self
.
seqemb_mode
:
m
,
z
=
self
.
input_embedder
(
feats
[
"target_feat"
],
feats
[
"residue_index"
],
feats
[
"seq_embedding"
]
)
else
:
#
#
Initialize the MSA and pair representations
# Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m
,
z
=
self
.
input_embedder
(
...
...
@@ -293,12 +281,12 @@ class AlphaFold(nn.Module):
requires_grad
=
False
,
)
x_prev
=
pseudo_beta_fn
(
pseudo_beta_
x_prev
=
pseudo_beta_fn
(
feats
[
"aatype"
],
x_prev
,
None
).
to
(
dtype
=
z
.
dtype
)
# The recycling embedder is memory-intensive, so we offload first
if
(
self
.
globals
.
offload_inference
and
inplace_safe
)
:
if
self
.
globals
.
offload_inference
and
inplace_safe
:
m
=
m
.
cpu
()
z
=
z
.
cpu
()
...
...
@@ -307,11 +295,13 @@ class AlphaFold(nn.Module):
m_1_prev_emb
,
z_prev_emb
=
self
.
recycling_embedder
(
m_1_prev
,
z_prev
,
x_prev
,
pseudo_beta_
x_prev
,
inplace_safe
=
inplace_safe
,
)
if
(
self
.
globals
.
offload_inference
and
inplace_safe
):
del
pseudo_beta_x_prev
if
self
.
globals
.
offload_inference
and
inplace_safe
:
m
=
m
.
to
(
m_1_prev_emb
.
device
)
z
=
z
.
to
(
z_prev
.
device
)
...
...
@@ -324,15 +314,17 @@ class AlphaFold(nn.Module):
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
del
m_1_prev
,
z_prev
,
m_1_prev_emb
,
z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if
self
.
config
.
template
.
enabled
:
template_feats
=
{
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
}
template_embeds
=
self
.
embed_templates
(
template_feats
,
feats
,
z
,
pair_mask
.
to
(
dtype
=
z
.
dtype
),
no_batch_dims
,
...
...
@@ -345,26 +337,40 @@ class AlphaFold(nn.Module):
inplace_safe
,
)
if
"template_angle_embedding"
in
template_embeds
:
if
(
"template_single_embedding"
in
template_embeds
):
# [*, S = S_c + S_t, N, C_m]
m
=
torch
.
cat
(
[
m
,
template_embeds
[
"template_
a
ngle_embedding"
]],
[
m
,
template_embeds
[
"template_
si
ngle_embedding"
]],
dim
=-
3
)
# [*, S, N]
if
not
self
.
globals
.
is_multimer
:
torsion_angles_mask
=
feats
[
"template_torsion_angles_mask"
]
msa_mask
=
torch
.
cat
(
[
feats
[
"msa_mask"
],
torsion_angles_mask
[...,
2
]],
dim
=-
2
)
else
:
msa_mask
=
torch
.
cat
(
[
feats
[
"msa_mask"
],
template_embeds
[
"template_mask"
]],
dim
=-
2
,
)
# Embed extra MSA features + merge with pairwise embeddings
if
self
.
config
.
extra_msa
.
enabled
:
if
self
.
globals
.
is_multimer
:
extra_msa_fn
=
data_transforms_multimer
.
build_extra_msa_feat
else
:
extra_msa_fn
=
build_extra_msa_feat
# [*, S_e, N, C_e]
a
=
self
.
extra_msa_embedder
(
build_extra_msa_feat
(
feats
))
extra_msa_feat
=
extra_msa_fn
(
feats
).
to
(
dtype
=
z
.
dtype
)
a
=
self
.
extra_msa_embedder
(
extra_msa_feat
)
if
(
self
.
globals
.
offload_inference
)
:
if
self
.
globals
.
offload_inference
:
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors
=
[
a
,
z
]
...
...
@@ -399,7 +405,7 @@ class AlphaFold(nn.Module):
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if
(
self
.
globals
.
offload_inference
)
:
if
self
.
globals
.
offload_inference
:
input_tensors
=
[
m
,
z
]
del
m
,
z
m
,
z
,
s
=
self
.
evoformer
.
_forward_offload
(
...
...
@@ -455,10 +461,34 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z_prev
=
outputs
[
"pair"
]
early_stop
=
False
if
self
.
globals
.
is_multimer
:
early_stop
=
self
.
tolerance_reached
(
x_prev
,
outputs
[
"final_atom_positions"
],
seq_mask
)
del
x_prev
# [*, N, 3]
x_prev
=
outputs
[
"final_atom_positions"
]
return
outputs
,
m_1_prev
,
z_prev
,
x_prev
return
outputs
,
m_1_prev
,
z_prev
,
x_prev
,
early_stop
def
_disable_activation_checkpointing
(
self
):
self
.
template_embedder
.
template_pair_stack
.
blocks_per_ckpt
=
None
self
.
evoformer
.
blocks_per_ckpt
=
None
for
b
in
self
.
extra_msa_stack
.
blocks
:
b
.
ckpt
=
False
def
_enable_activation_checkpointing
(
self
):
self
.
template_embedder
.
template_pair_stack
.
blocks_per_ckpt
=
(
self
.
config
.
template
.
template_pair_stack
.
blocks_per_ckpt
)
self
.
evoformer
.
blocks_per_ckpt
=
(
self
.
config
.
evoformer_stack
.
blocks_per_ckpt
)
for
b
in
self
.
extra_msa_stack
.
blocks
:
b
.
ckpt
=
self
.
config
.
extra_msa
.
extra_msa_stack
.
ckpt
def
forward
(
self
,
batch
):
"""
...
...
@@ -519,13 +549,15 @@ class AlphaFold(nn.Module):
# Main recycling loop
num_iters
=
batch
[
"aatype"
].
shape
[
-
1
]
early_stop
=
False
num_recycles
=
0
for
cycle_no
in
range
(
num_iters
):
# Select the features for the current recycling cycle
fetch_cur_batch
=
lambda
t
:
t
[...,
cycle_no
]
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter
=
cycle_no
==
(
num_iters
-
1
)
is_final_iter
=
cycle_no
==
(
num_iters
-
1
)
or
early_stop
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
if
is_final_iter
:
# Sidestep AMP bug (PyTorch issue #65766)
...
...
@@ -533,16 +565,25 @@ class AlphaFold(nn.Module):
torch
.
clear_autocast_cache
()
# Run the next iteration of the model
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
outputs
,
m_1_prev
,
z_prev
,
x_prev
,
early_stop
=
self
.
iteration
(
feats
,
prevs
,
_recycle
=
(
num_iters
>
1
)
)
if
(
not
is_final_iter
):
num_recycles
+=
1
if
not
is_final_iter
:
del
outputs
prevs
=
[
m_1_prev
,
z_prev
,
x_prev
]
del
m_1_prev
,
z_prev
,
x_prev
else
:
break
outputs
[
"num_recycles"
]
=
torch
.
tensor
(
num_recycles
,
device
=
feats
[
"aatype"
].
device
)
if
"asym_id"
in
batch
:
outputs
[
"asym_id"
]
=
feats
[
"asym_id"
]
# Run auxiliary heads
outputs
.
update
(
self
.
aux_heads
(
outputs
))
...
...
openfold/model/primitives.py
View file @
bb3f51e5
...
...
@@ -131,6 +131,7 @@ class Linear(nn.Linear):
bias
:
bool
=
True
,
init
:
str
=
"default"
,
init_fn
:
Optional
[
Callable
[[
torch
.
Tensor
,
torch
.
Tensor
],
None
]]
=
None
,
precision
=
None
):
"""
Args:
...
...
@@ -182,6 +183,28 @@ class Linear(nn.Linear):
else
:
raise
ValueError
(
"Invalid init string."
)
self
.
precision
=
precision
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Apply the linear layer, optionally forcing the compute dtype.

    Three paths, checked in this order:

    1. ``self.precision`` is set: autocast is disabled, the input, weight
       and bias (if any) are cast to ``self.precision``, and the result is
       cast back to the dtype of ``input``.
    2. ``input`` is bfloat16 and DeepSpeed is not initialized: autocast is
       disabled and the weight and bias (if any) are cast to bfloat16.
    3. Otherwise: a plain ``nn.functional.linear`` call.

    Args:
        input: Tensor to project.

    Returns:
        The projected tensor.
    """
    in_dtype = input.dtype
    ds_active = (
        deepspeed_is_installed
        and deepspeed.comm.comm.is_initialized()
    )

    if self.precision is not None:
        forced = self.precision
        with torch.cuda.amp.autocast(enabled=False):
            cast_bias = None
            if self.bias is not None:
                cast_bias = self.bias.to(dtype=forced)
            result = nn.functional.linear(
                input.to(dtype=forced),
                self.weight.to(dtype=forced),
                cast_bias,
            )
            # Hand the caller back the dtype it passed in.
            return result.to(dtype=in_dtype)

    needs_manual_bf16 = in_dtype is torch.bfloat16 and not ds_active
    if not needs_manual_bf16:
        return nn.functional.linear(input, self.weight, self.bias)

    with torch.cuda.amp.autocast(enabled=False):
        cast_bias = None if self.bias is None else self.bias.to(dtype=in_dtype)
        return nn.functional.linear(
            input,
            self.weight.to(dtype=in_dtype),
            cast_bias,
        )
class
LayerNorm
(
nn
.
Module
):
def
__init__
(
self
,
c_in
,
eps
=
1e-5
):
...
...
openfold/model/structure_module.py
View file @
bb3f51e5
...
...
@@ -20,7 +20,7 @@ from operator import mul
import
torch
import
torch.nn
as
nn
from
typing
import
Optional
,
Tuple
,
Sequence
from
typing
import
Optional
,
Tuple
,
Sequence
,
Union
from
openfold.model.primitives
import
Linear
,
LayerNorm
,
ipa_point_weights_init_
from
openfold.np.residue_constants
import
(
...
...
@@ -29,6 +29,9 @@ from openfold.np.residue_constants import (
restype_atom14_mask
,
restype_atom14_rigid_group_positions
,
)
from
openfold.utils.geometry.quat_rigid
import
QuatRigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.vector
import
Vec3Array
,
square_euclidean_distance
from
openfold.utils.feats
import
(
frames_and_literature_positions_to_atom14_pos
,
torsion_angles_to_frames
,
...
...
@@ -158,6 +161,51 @@ class AngleResnet(nn.Module):
return
unnormalized_s
,
s
class PointProjection(nn.Module):
    """Project activations to 3-D points and map them through rigid transforms.

    A single ``Linear`` layer produces ``no_heads * num_points * 3`` values
    per position. They are rearranged to ``[..., no_heads, num_points, 3]``
    and passed to ``rigids[..., None, None].apply`` to obtain the global
    points.
    """

    def __init__(
        self,
        c_hidden: int,
        num_points: int,
        no_heads: int,
        is_multimer: bool,
        return_local_points: bool = False,
    ):
        """
        Args:
            c_hidden:
                Input channel dimension of the projection
            num_points:
                Number of points generated per head
            no_heads:
                Number of attention heads
            is_multimer:
                Selects the multimer layout of the projection output and
                forces the projection to run in fp32
            return_local_points:
                If True, forward also returns the untransformed points
        """
        super().__init__()

        self.return_local_points = return_local_points
        self.no_heads = no_heads
        self.num_points = num_points
        self.is_multimer = is_multimer

        # Multimer requires this to be run with fp32 precision during training
        self.linear = Linear(
            c_hidden,
            no_heads * 3 * num_points,
            precision=torch.float32 if is_multimer else None,
        )

    def forward(self,
        activations: torch.Tensor,
        rigids: Union[Rigid, Rigid3Array],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            activations:
                Tensor whose last dimension is projected to points
            rigids:
                Transformation object applied to the local points
        Returns:
            The transformed points, or a (transformed, local) pair when
            ``return_local_points`` is set.
        """
        # TODO: Needs to run in high precision during training
        projected = self.linear(activations)
        lead_dims = projected.shape[:-1]
        target_shape = lead_dims + (self.no_heads, self.num_points, 3)

        # Multimer: separate the heads first, so the three coordinate
        # chunks below are taken per head instead of across all heads.
        if self.is_multimer:
            projected = projected.view(lead_dims + (self.no_heads, -1))

        # Cut the last dimension into three equal chunks and stack them on
        # a new trailing axis of size 3.
        third = projected.shape[-1] // 3
        chunks = torch.split(projected, third, dim=-1)
        points_local = torch.stack(chunks, dim=-1).view(target_shape)

        points_global = rigids[..., None, None].apply(points_local)

        if not self.return_local_points:
            return points_global

        return points_global, points_local
class
InvariantPointAttention
(
nn
.
Module
):
"""
Implements Algorithm 22.
...
...
@@ -172,6 +220,7 @@ class InvariantPointAttention(nn.Module):
no_v_points
:
int
,
inf
:
float
=
1e5
,
eps
:
float
=
1e-8
,
is_multimer
:
bool
=
False
,
):
"""
Args:
...
...
@@ -198,22 +247,46 @@ class InvariantPointAttention(nn.Module):
self
.
no_v_points
=
no_v_points
self
.
inf
=
inf
self
.
eps
=
eps
self
.
is_multimer
=
is_multimer
# These linear layers differ from their specifications in the
# supplement. There, they lack bias and use Glorot initialization.
# Here as in the official source, they have bias and use the default
# Lecun initialization.
hc
=
self
.
c_hidden
*
self
.
no_heads
self
.
linear_q
=
Linear
(
self
.
c_s
,
hc
)
self
.
linear_kv
=
Linear
(
self
.
c_s
,
2
*
hc
)
self
.
linear_q
=
Linear
(
self
.
c_s
,
hc
,
bias
=
(
not
is_multimer
))
hpq
=
self
.
no_heads
*
self
.
no_qk_points
*
3
self
.
linear_q_points
=
Linear
(
self
.
c_s
,
hpq
)
self
.
linear_q_points
=
PointProjection
(
self
.
c_s
,
self
.
no_qk_points
,
self
.
no_heads
,
self
.
is_multimer
)
hpkv
=
self
.
no_heads
*
(
self
.
no_qk_points
+
self
.
no_v_points
)
*
3
self
.
linear_kv_points
=
Linear
(
self
.
c_s
,
hpkv
)
if
(
is_multimer
):
self
.
linear_k
=
Linear
(
self
.
c_s
,
hc
,
bias
=
False
)
self
.
linear_v
=
Linear
(
self
.
c_s
,
hc
,
bias
=
False
)
self
.
linear_k_points
=
PointProjection
(
self
.
c_s
,
self
.
no_qk_points
,
self
.
no_heads
,
self
.
is_multimer
)
hpv
=
self
.
no_heads
*
self
.
no_v_points
*
3
self
.
linear_v_points
=
PointProjection
(
self
.
c_s
,
self
.
no_v_points
,
self
.
no_heads
,
self
.
is_multimer
)
else
:
self
.
linear_kv
=
Linear
(
self
.
c_s
,
2
*
hc
)
self
.
linear_kv_points
=
PointProjection
(
self
.
c_s
,
self
.
no_qk_points
+
self
.
no_v_points
,
self
.
no_heads
,
self
.
is_multimer
)
self
.
linear_b
=
Linear
(
self
.
c_z
,
self
.
no_heads
)
...
...
@@ -231,8 +304,8 @@ class InvariantPointAttention(nn.Module):
def
forward
(
self
,
s
:
torch
.
Tensor
,
z
:
Optional
[
torch
.
Tensor
]
,
r
:
Rigid
,
z
:
torch
.
Tensor
,
r
:
Union
[
Rigid
,
Rigid3Array
]
,
mask
:
torch
.
Tensor
,
inplace_safe
:
bool
=
False
,
_offload_inference
:
bool
=
False
,
...
...
@@ -251,7 +324,7 @@ class InvariantPointAttention(nn.Module):
Returns:
[*, N_res, C_s] single representation update
"""
if
(
_offload_inference
and
inplace_safe
):
if
(
_offload_inference
and
inplace_safe
):
z
=
_z_reference_list
else
:
z
=
[
z
]
...
...
@@ -261,41 +334,40 @@ class InvariantPointAttention(nn.Module):
#######################################
# [*, N_res, H * C_hidden]
q
=
self
.
linear_q
(
s
)
kv
=
self
.
linear_kv
(
s
)
# [*, N_res, H, C_hidden]
q
=
q
.
view
(
q
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
# [*, N_res, H,
2 * C_hidden
]
kv
=
kv
.
view
(
kv
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
)
)
# [*, N_res, H,
P_qk
]
q_pts
=
self
.
linear_q_points
(
s
,
r
)
# [*, N_res, H, C_hidden]
k
,
v
=
torch
.
split
(
kv
,
self
.
c_hidden
,
dim
=-
1
)
# The following two blocks are equivalent
# They're separated only to preserve compatibility with old AF weights
if
(
self
.
is_multimer
):
# [*, N_res, H * C_hidden]
k
=
self
.
linear_k
(
s
)
v
=
self
.
linear_v
(
s
)
# [*, N_res, H * P_q * 3]
q_pts
=
self
.
linear_q_points
(
s
)
# [*, N_res, H, C_hidden]
k
=
k
.
view
(
k
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
v
=
v
.
view
(
v
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
# This is kind of clunky, but it's how the original does it
# [*, N_res, H * P_q, 3]
q_pts
=
torch
.
split
(
q_pts
,
q_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
q_pts
=
torch
.
stack
(
q_pts
,
dim
=-
1
)
q_pts
=
r
[...,
None
].
apply
(
q_pts
)
# [*, N_res, H, P_qk, 3]
k_pts
=
self
.
linear_k_points
(
s
,
r
)
# [*, N_res, H, P_q, 3]
q_pts
=
q_pts
.
view
(
q_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
self
.
no_qk_points
,
3
)
)
# [*, N_res, H, P_v, 3]
v_pts
=
self
.
linear_v_points
(
s
,
r
)
else
:
# [*, N_res, H * 2 * C_hidden]
kv
=
self
.
linear_kv
(
s
)
# [*, N_res, H
* (P_q + P_v) * 3
]
kv_pts
=
self
.
linear_kv_points
(
s
)
# [*, N_res, H
, 2 * C_hidden
]
kv
=
kv
.
view
(
kv
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
)
)
# [*, N_res, H * (P_q + P_v), 3]
kv_pts
=
torch
.
split
(
kv_pts
,
kv_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
kv_pts
=
torch
.
stack
(
kv_pts
,
dim
=-
1
)
kv_pts
=
r
[...,
None
].
apply
(
kv_pts
)
# [*, N_res, H, C_hidden]
k
,
v
=
torch
.
split
(
kv
,
self
.
c_hidden
,
dim
=-
1
)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts
=
kv_pts
.
view
(
kv_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
-
1
,
3
))
kv_pts
=
self
.
linear_kv_points
(
s
,
r
)
# [*, N_res, H, P_q/P_v, 3]
k_pts
,
v_pts
=
torch
.
split
(
...
...
@@ -308,12 +380,12 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, N_res, H]
b
=
self
.
linear_b
(
z
[
0
])
if
(
_offload_inference
):
assert
(
sys
.
getrefcount
(
z
[
0
])
==
2
)
if
(
_offload_inference
):
assert
(
sys
.
getrefcount
(
z
[
0
])
==
2
)
z
[
0
]
=
z
[
0
].
cpu
()
# [*, H, N_res, N_res]
if
(
is_fp16_enabled
()):
if
(
is_fp16_enabled
()):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
a
=
torch
.
matmul
(
permute_final_dims
(
q
.
float
(),
(
1
,
0
,
2
)),
# [*, H, N_res, C_hidden]
...
...
@@ -330,26 +402,29 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, N_res, H, P_q, 3]
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
if
(
inplace_safe
):
if
(
inplace_safe
):
pt_att
*=
pt_att
else
:
pt_att
=
pt_att
**
2
# [*, N_res, N_res, H, P_q]
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
))
head_weights
=
self
.
softplus
(
self
.
head_weights
).
view
(
*
((
1
,)
*
len
(
pt_att
.
shape
[:
-
2
])
+
(
-
1
,
1
))
)
head_weights
=
head_weights
*
math
.
sqrt
(
1.0
/
(
3
*
(
self
.
no_qk_points
*
9.0
/
2
))
)
if
(
inplace_safe
):
if
(
inplace_safe
):
pt_att
*=
head_weights
else
:
pt_att
=
pt_att
*
head_weights
# [*, N_res, N_res, H]
pt_att
=
torch
.
sum
(
pt_att
,
dim
=-
1
)
*
(
-
0.5
)
# [*, N_res, N_res]
square_mask
=
mask
.
unsqueeze
(
-
1
)
*
mask
.
unsqueeze
(
-
2
)
square_mask
=
self
.
inf
*
(
square_mask
-
1
)
...
...
@@ -357,7 +432,7 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
pt_att
=
permute_final_dims
(
pt_att
,
(
2
,
0
,
1
))
if
(
inplace_safe
):
if
(
inplace_safe
):
a
+=
pt_att
del
pt_att
a
+=
square_mask
.
unsqueeze
(
-
3
)
...
...
@@ -384,7 +459,7 @@ class InvariantPointAttention(nn.Module):
o
=
flatten_final_dims
(
o
,
2
)
# [*, H, 3, N_res, P_v]
if
(
inplace_safe
):
if
(
inplace_safe
):
v_pts
=
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))
o_pt
=
[
torch
.
matmul
(
a
,
v
.
to
(
a
.
dtype
))
...
...
@@ -411,8 +486,9 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_v, 3]
o_pt
=
o_pt
.
reshape
(
*
o_pt
.
shape
[:
-
3
],
-
1
,
3
)
o_pt
=
torch
.
unbind
(
o_pt
,
dim
=-
1
)
if
(
_offload_inference
):
if
(
_offload_inference
):
z
[
0
]
=
z
[
0
].
to
(
o_pt
.
device
)
# [*, N_res, H, C_z]
...
...
@@ -424,7 +500,233 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, C_s]
s
=
self
.
linear_out
(
torch
.
cat
(
(
o
,
*
torch
.
unbind
(
o_pt
,
dim
=-
1
),
o_pt_norm
,
o_pair
),
dim
=-
1
(
o
,
*
o_pt
,
o_pt_norm
,
o_pair
),
dim
=-
1
).
to
(
dtype
=
z
[
0
].
dtype
)
)
return
s
# TODO: This module follows the refactoring done for IPA in the multimer model. Running the regular
# InvariantPointAttention above in multimer mode should be equivalent, but the tests only pass when
# this version is used. Determine whether the increase in test error matters in practice.
class
InvariantPointAttentionMultimer
(
nn
.
Module
):
"""
Implements Algorithm 22.
"""
def __init__(
    self,
    c_s: int,
    c_z: int,
    c_hidden: int,
    no_heads: int,
    no_qk_points: int,
    no_v_points: int,
    inf: float = 1e5,
    eps: float = 1e-8,
    is_multimer: bool = True,
):
    """
    Args:
        c_s:
            Single representation channel dimension
        c_z:
            Pair representation channel dimension
        c_hidden:
            Hidden channel dimension
        no_heads:
            Number of attention heads
        no_qk_points:
            Number of query/key points to generate
        no_v_points:
            Number of value points to generate
        inf:
            Stored on the module as ``self.inf``
        eps:
            Stored on the module as ``self.eps``
        is_multimer:
            Accepted but not read by this constructor; every point
            projection below is built with ``is_multimer=True``
    """
    super().__init__()

    self.c_s = c_s
    self.c_z = c_z
    self.c_hidden = c_hidden
    self.no_heads = no_heads
    self.no_qk_points = no_qk_points
    self.no_v_points = no_v_points
    self.inf = inf
    self.eps = eps

    # The scalar query/key/value projections are created without bias.
    # The pair-bias and output projections keep the Linear default.
    # Sub-modules are created in a fixed order; keep it unchanged.
    scalar_dim = c_hidden * no_heads

    self.linear_q = Linear(c_s, scalar_dim, bias=False)
    self.linear_q_points = PointProjection(
        c_s, no_qk_points, no_heads, is_multimer=True
    )

    self.linear_k = Linear(c_s, scalar_dim, bias=False)
    self.linear_v = Linear(c_s, scalar_dim, bias=False)
    self.linear_k_points = PointProjection(
        c_s, no_qk_points, no_heads, is_multimer=True
    )
    self.linear_v_points = PointProjection(
        c_s, no_v_points, no_heads, is_multimer=True
    )

    self.linear_b = Linear(c_z, no_heads)

    # One learnable weight per head.
    self.head_weights = nn.Parameter(torch.zeros(no_heads))
    ipa_point_weights_init_(self.head_weights)

    # Input width of the output projection: per head, c_z + c_hidden
    # channels plus four values for each value point.
    out_in_dim = no_heads * (c_z + c_hidden + no_v_points * 4)
    self.linear_out = Linear(out_in_dim, c_s, init="final")

    self.softmax = nn.Softmax(dim=-2)
def
forward
(
self
,
s
:
torch
.
Tensor
,
z
:
Optional
[
torch
.
Tensor
],
r
:
Union
[
Rigid
,
Rigid3Array
],
mask
:
torch
.
Tensor
,
inplace_safe
:
bool
=
False
,
_offload_inference
:
bool
=
False
,
_z_reference_list
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
"""
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
r:
[*, N_res] transformation object
mask:
[*, N_res] mask
Returns:
[*, N_res, C_s] single representation update
"""
if
(
_offload_inference
and
inplace_safe
):
z
=
_z_reference_list
else
:
z
=
[
z
]
a
=
0.
point_variance
=
(
max
(
self
.
no_qk_points
,
1
)
*
9.0
/
2
)
point_weights
=
math
.
sqrt
(
1.0
/
point_variance
)
softplus
=
lambda
x
:
torch
.
logaddexp
(
x
,
torch
.
zeros_like
(
x
))
head_weights
=
softplus
(
self
.
head_weights
)
point_weights
=
point_weights
*
head_weights
#######################################
# Generate scalar and point activations
#######################################
# [*, N_res, H, P_qk]
q_pts
=
Vec3Array
.
from_array
(
self
.
linear_q_points
(
s
,
r
))
# [*, N_res, H, P_qk, 3]
k_pts
=
Vec3Array
.
from_array
(
self
.
linear_k_points
(
s
,
r
))
pt_att
=
square_euclidean_distance
(
q_pts
.
unsqueeze
(
-
3
),
k_pts
.
unsqueeze
(
-
4
),
epsilon
=
0.
)
pt_att
=
torch
.
sum
(
pt_att
*
point_weights
[...,
None
],
dim
=-
1
)
*
(
-
0.5
)
pt_att
=
pt_att
.
to
(
dtype
=
s
.
dtype
)
a
=
a
+
pt_att
scalar_variance
=
max
(
self
.
c_hidden
,
1
)
*
1.
scalar_weights
=
math
.
sqrt
(
1.0
/
scalar_variance
)
# [*, N_res, H * C_hidden]
q
=
self
.
linear_q
(
s
)
k
=
self
.
linear_k
(
s
)
# [*, N_res, H, C_hidden]
q
=
q
.
view
(
q
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
k
=
k
.
view
(
k
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
q
=
q
*
scalar_weights
a
=
a
+
torch
.
einsum
(
'...qhc,...khc->...qkh'
,
q
,
k
)
##########################
# Compute attention scores
##########################
# [*, N_res, N_res, H]
b
=
self
.
linear_b
(
z
[
0
])
if
(
_offload_inference
):
assert
(
sys
.
getrefcount
(
z
[
0
])
==
2
)
z
[
0
]
=
z
[
0
].
cpu
()
a
=
a
+
b
# [*, N_res, N_res]
square_mask
=
mask
.
unsqueeze
(
-
1
)
*
mask
.
unsqueeze
(
-
2
)
square_mask
=
self
.
inf
*
(
square_mask
-
1
)
a
=
a
+
square_mask
.
unsqueeze
(
-
1
)
a
=
a
*
math
.
sqrt
(
1.
/
3
)
# Normalize by number of logit terms (3)
a
=
self
.
softmax
(
a
)
# [*, N_res, H * C_hidden]
v
=
self
.
linear_v
(
s
)
# [*, N_res, H, C_hidden]
v
=
v
.
view
(
v
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
o
=
torch
.
einsum
(
'...qkh, ...khc->...qhc'
,
a
,
v
)
# [*, N_res, H * C_hidden]
o
=
flatten_final_dims
(
o
,
2
)
# [*, N_res, H, P_v, 3]
v_pts
=
Vec3Array
.
from_array
(
self
.
linear_v_points
(
s
,
r
))
# [*, N_res, H, P_v]
o_pt
=
v_pts
[...,
None
,
:,
:,
:]
*
a
.
unsqueeze
(
-
1
)
o_pt
=
o_pt
.
sum
(
dim
=-
3
)
# o_pt = Vec3Array(
# torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].x, dim=-3),
# torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].y, dim=-3),
# torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].z, dim=-3),
# )
# [*, N_res, H * P_v, 3]
o_pt
=
o_pt
.
reshape
(
o_pt
.
shape
[:
-
2
]
+
(
-
1
,))
# [*, N_res, H, P_v]
o_pt
=
r
[...,
None
].
apply_inverse_to_point
(
o_pt
)
o_pt_flat
=
[
o_pt
.
x
,
o_pt
.
y
,
o_pt
.
z
]
o_pt_flat
=
[
x
.
to
(
dtype
=
a
.
dtype
)
for
x
in
o_pt_flat
]
# [*, N_res, H * P_v]
o_pt_norm
=
o_pt
.
norm
(
epsilon
=
1e-8
)
if
(
_offload_inference
):
z
[
0
]
=
z
[
0
].
to
(
o_pt
.
x
.
device
)
o_pair
=
torch
.
einsum
(
'...ijh, ...ijc->...ihc'
,
a
,
z
[
0
].
to
(
dtype
=
a
.
dtype
))
# [*, N_res, H * C_z]
o_pair
=
flatten_final_dims
(
o_pair
,
2
)
# [*, N_res, C_s]
s
=
self
.
linear_out
(
torch
.
cat
(
(
o
,
*
o_pt_flat
,
o_pt_norm
,
o_pair
),
dim
=-
1
).
to
(
dtype
=
z
[
0
].
dtype
)
)
...
...
@@ -530,6 +832,7 @@ class StructureModule(nn.Module):
trans_scale_factor
,
epsilon
,
inf
,
is_multimer
=
False
,
**
kwargs
,
):
"""
...
...
@@ -583,6 +886,7 @@ class StructureModule(nn.Module):
self
.
trans_scale_factor
=
trans_scale_factor
self
.
epsilon
=
epsilon
self
.
inf
=
inf
self
.
is_multimer
=
is_multimer
# Buffers to be lazily initialized later
# self.default_frames
...
...
@@ -595,7 +899,8 @@ class StructureModule(nn.Module):
self
.
linear_in
=
Linear
(
self
.
c_s
,
self
.
c_s
)
self
.
ipa
=
InvariantPointAttention
(
ipa
=
InvariantPointAttention
if
not
self
.
is_multimer
else
InvariantPointAttentionMultimer
self
.
ipa
=
ipa
(
self
.
c_s
,
self
.
c_z
,
self
.
c_ipa
,
...
...
@@ -604,6 +909,7 @@ class StructureModule(nn.Module):
self
.
no_v_points
,
inf
=
self
.
inf
,
eps
=
self
.
epsilon
,
is_multimer
=
self
.
is_multimer
,
)
self
.
ipa_dropout
=
nn
.
Dropout
(
self
.
dropout_rate
)
...
...
@@ -615,6 +921,9 @@ class StructureModule(nn.Module):
self
.
dropout_rate
,
)
if
self
.
is_multimer
:
self
.
bb_update
=
QuatRigid
(
self
.
c_s
,
full_quat
=
False
)
else
:
self
.
bb_update
=
BackboneUpdate
(
self
.
c_s
)
self
.
angle_resnet
=
AngleResnet
(
...
...
@@ -625,7 +934,7 @@ class StructureModule(nn.Module):
self
.
epsilon
,
)
def
forward
(
def
_
forward
_monomer
(
self
,
evoformer_output_dict
,
aatype
,
...
...
@@ -661,8 +970,8 @@ class StructureModule(nn.Module):
z
=
self
.
layer_norm_z
(
evoformer_output_dict
[
"pair"
])
z_reference_list
=
None
if
(
_offload_inference
):
assert
(
sys
.
getrefcount
(
evoformer_output_dict
[
"pair"
])
==
2
)
if
(
_offload_inference
):
assert
(
sys
.
getrefcount
(
evoformer_output_dict
[
"pair"
])
==
2
)
evoformer_output_dict
[
"pair"
]
=
evoformer_output_dict
[
"pair"
].
cpu
()
z_reference_list
=
[
z
]
z
=
None
...
...
@@ -744,7 +1053,102 @@ class StructureModule(nn.Module):
del
z
,
z_reference_list
if
(
_offload_inference
):
if
(
_offload_inference
):
evoformer_output_dict
[
"pair"
]
=
(
evoformer_output_dict
[
"pair"
].
to
(
s
.
device
)
)
outputs
=
dict_multimap
(
torch
.
stack
,
outputs
)
outputs
[
"single"
]
=
s
return
outputs
def
_forward_multimer
(
self
,
evoformer_output_dict
,
aatype
,
mask
=
None
,
inplace_safe
=
False
,
_offload_inference
=
False
,
):
s
=
evoformer_output_dict
[
"single"
]
if
mask
is
None
:
# [*, N]
mask
=
s
.
new_ones
(
s
.
shape
[:
-
1
])
# [*, N, C_s]
s
=
self
.
layer_norm_s
(
s
)
# [*, N, N, C_z]
z
=
self
.
layer_norm_z
(
evoformer_output_dict
[
"pair"
])
z_reference_list
=
None
if
(
_offload_inference
):
assert
(
sys
.
getrefcount
(
evoformer_output_dict
[
"pair"
])
==
2
)
evoformer_output_dict
[
"pair"
]
=
evoformer_output_dict
[
"pair"
].
cpu
()
z_reference_list
=
[
z
]
z
=
None
# [*, N, C_s]
s_initial
=
s
s
=
self
.
linear_in
(
s
)
# [*, N]
rigids
=
Rigid3Array
.
identity
(
s
.
shape
[:
-
1
],
s
.
device
,
)
outputs
=
[]
for
i
in
range
(
self
.
no_blocks
):
# [*, N, C_s]
s
=
s
+
self
.
ipa
(
s
,
z
,
rigids
,
mask
,
inplace_safe
=
inplace_safe
,
_offload_inference
=
_offload_inference
,
_z_reference_list
=
z_reference_list
)
s
=
self
.
ipa_dropout
(
s
)
s
=
self
.
layer_norm_ipa
(
s
)
s
=
self
.
transition
(
s
)
# [*, N]
rigids
=
rigids
@
self
.
bb_update
(
s
)
# [*, N, 7, 2]
unnormalized_angles
,
angles
=
self
.
angle_resnet
(
s
,
s_initial
)
all_frames_to_global
=
self
.
torsion_angles_to_frames
(
rigids
.
scale_translation
(
self
.
trans_scale_factor
),
angles
,
aatype
,
)
pred_xyz
=
self
.
frames_and_literature_positions_to_atom14_pos
(
all_frames_to_global
,
aatype
,
)
preds
=
{
"frames"
:
rigids
.
scale_translation
(
self
.
trans_scale_factor
).
to_tensor
(),
"sidechain_frames"
:
all_frames_to_global
.
to_tensor_4x4
(),
"unnormalized_angles"
:
unnormalized_angles
,
"angles"
:
angles
,
"positions"
:
pred_xyz
,
}
preds
=
{
k
:
v
.
to
(
dtype
=
s
.
dtype
)
for
k
,
v
in
preds
.
items
()}
outputs
.
append
(
preds
)
rigids
=
rigids
.
stop_rot_gradient
()
del
z
,
z_reference_list
if
(
_offload_inference
):
evoformer_output_dict
[
"pair"
]
=
(
evoformer_output_dict
[
"pair"
].
to
(
s
.
device
)
)
...
...
@@ -754,6 +1158,34 @@ class StructureModule(nn.Module):
return
outputs
def forward(
    self,
    evoformer_output_dict,
    aatype,
    mask=None,
    inplace_safe=False,
    _offload_inference=False,
):
    """
    Dispatch to the multimer or monomer implementation, chosen by
    ``self.is_multimer``. All arguments are forwarded positionally and
    unchanged.

    Args:
        evoformer_output_dict:
            Evoformer outputs, passed through to the implementation
        aatype:
            [*, N_res] amino acid indices
        mask:
            Optional [*, N_res] sequence mask
        inplace_safe:
            Passed through to the implementation
        _offload_inference:
            Passed through to the implementation
    Returns:
        Whatever the selected implementation returns
    """
    impl = self._forward_multimer if self.is_multimer else self._forward_monomer
    return impl(
        evoformer_output_dict,
        aatype,
        mask,
        inplace_safe,
        _offload_inference,
    )
def
_init_residue_constants
(
self
,
float_dtype
,
device
):
if
not
hasattr
(
self
,
"default_frames"
):
self
.
register_buffer
(
...
...
@@ -809,7 +1241,7 @@ class StructureModule(nn.Module):
self
,
r
,
f
# [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
r
.
get_rots
().
dtype
,
r
.
get_rots
()
.
device
)
self
.
_init_residue_constants
(
r
.
dtype
,
r
.
device
)
return
frames_and_literature_positions_to_atom14_pos
(
r
,
f
,
...
...
openfold/model/template.py
View file @
bb3f51e5
...
...
@@ -33,6 +33,8 @@ from openfold.model.triangular_attention import (
from
openfold.model.triangular_multiplicative_update
import
(
TriangleMultiplicationOutgoing
,
TriangleMultiplicationIncoming
,
FusedTriangleMultiplicationOutgoing
,
FusedTriangleMultiplicationIncoming
)
from
openfold.utils.checkpointing
import
checkpoint_blocks
from
openfold.utils.chunk_utils
import
(
...
...
@@ -54,6 +56,7 @@ class TemplatePointwiseAttention(nn.Module):
"""
Implements Algorithm 17.
"""
def
__init__
(
self
,
c_t
,
c_z
,
c_hidden
,
no_heads
,
inf
,
**
kwargs
):
"""
Args:
...
...
@@ -100,7 +103,6 @@ class TemplatePointwiseAttention(nn.Module):
no_batch_dims
=
len
(
z
.
shape
[:
-
2
]),
)
def
forward
(
self
,
t
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
...
...
@@ -153,6 +155,8 @@ class TemplatePairStackBlock(nn.Module):
no_heads
:
int
,
pair_transition_n
:
int
,
dropout_rate
:
float
,
tri_mul_first
:
bool
,
fuse_projection_weights
:
bool
,
inf
:
float
,
**
kwargs
,
):
...
...
@@ -165,6 +169,7 @@ class TemplatePairStackBlock(nn.Module):
self
.
pair_transition_n
=
pair_transition_n
self
.
dropout_rate
=
dropout_rate
self
.
inf
=
inf
self
.
tri_mul_first
=
tri_mul_first
self
.
dropout_row
=
DropoutRowwise
(
self
.
dropout_rate
)
self
.
dropout_col
=
DropoutColumnwise
(
self
.
dropout_rate
)
...
...
@@ -182,6 +187,16 @@ class TemplatePairStackBlock(nn.Module):
inf
=
inf
,
)
if
fuse_projection_weights
:
self
.
tri_mul_out
=
FusedTriangleMultiplicationOutgoing
(
self
.
c_t
,
self
.
c_hidden_tri_mul
,
)
self
.
tri_mul_in
=
FusedTriangleMultiplicationIncoming
(
self
.
c_t
,
self
.
c_hidden_tri_mul
,
)
else
:
self
.
tri_mul_out
=
TriangleMultiplicationOutgoing
(
self
.
c_t
,
self
.
c_hidden_tri_mul
,
...
...
@@ -196,30 +211,13 @@ class TemplatePairStackBlock(nn.Module):
self
.
pair_transition_n
,
)
def
forward
(
self
,
z
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_deepspeed_evo_attention
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
):
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
single_templates
=
[
t
.
unsqueeze
(
-
4
)
for
t
in
torch
.
unbind
(
z
,
dim
=-
4
)
]
single_templates_masks
=
[
m
.
unsqueeze
(
-
3
)
for
m
in
torch
.
unbind
(
mask
,
dim
=-
3
)
]
for
i
in
range
(
len
(
single_templates
)):
single
=
single_templates
[
i
]
single_mask
=
single_templates_masks
[
i
]
def
tri_att_start_end
(
self
,
single
:
torch
.
Tensor
,
_attn_chunk_size
:
Optional
[
int
],
single_mask
:
torch
.
Tensor
,
use_deepspeed_evo_attention
:
bool
,
use_lma
:
bool
,
inplace_safe
:
bool
):
single
=
add
(
single
,
self
.
dropout_row
(
self
.
tri_att_start
(
...
...
@@ -248,13 +246,19 @@ class TemplatePairStackBlock(nn.Module):
inplace_safe
,
)
return
single
def
tri_mul_out_in
(
self
,
single
:
torch
.
Tensor
,
single_mask
:
torch
.
Tensor
,
inplace_safe
:
bool
):
tmu_update
=
self
.
tri_mul_out
(
single
,
mask
=
single_mask
,
inplace_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
)
if
(
not
inplace_safe
)
:
if
not
inplace_safe
:
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
single
=
tmu_update
...
...
@@ -267,13 +271,59 @@ class TemplatePairStackBlock(nn.Module):
inplace_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
)
if
(
not
inplace_safe
)
:
if
not
inplace_safe
:
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
single
=
tmu_update
del
tmu_update
return
single
def
forward
(
self
,
z
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_deepspeed_evo_attention
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
):
if
_attn_chunk_size
is
None
:
_attn_chunk_size
=
chunk_size
single_templates
=
[
t
.
unsqueeze
(
-
4
)
for
t
in
torch
.
unbind
(
z
,
dim
=-
4
)
]
single_templates_masks
=
[
m
.
unsqueeze
(
-
3
)
for
m
in
torch
.
unbind
(
mask
,
dim
=-
3
)
]
for
i
in
range
(
len
(
single_templates
)):
single
=
single_templates
[
i
]
single_mask
=
single_templates_masks
[
i
]
if
self
.
tri_mul_first
:
single
=
self
.
tri_att_start_end
(
single
=
self
.
tri_mul_out_in
(
single
=
single
,
single_mask
=
single_mask
,
inplace_safe
=
inplace_safe
),
_attn_chunk_size
=
_attn_chunk_size
,
single_mask
=
single_mask
,
use_deepspeed_evo_attention
=
use_deepspeed_evo_attention
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
)
else
:
single
=
self
.
tri_mul_out_in
(
single
=
self
.
tri_att_start_end
(
single
=
single
,
_attn_chunk_size
=
_attn_chunk_size
,
single_mask
=
single_mask
,
use_deepspeed_evo_attention
=
use_deepspeed_evo_attention
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
),
single_mask
=
single_mask
,
inplace_safe
=
inplace_safe
)
single
=
add
(
single
,
self
.
pair_transition
(
single
,
...
...
@@ -283,10 +333,10 @@ class TemplatePairStackBlock(nn.Module):
inplace_safe
,
)
if
(
not
inplace_safe
)
:
if
not
inplace_safe
:
single_templates
[
i
]
=
single
if
(
not
inplace_safe
)
:
if
not
inplace_safe
:
z
=
torch
.
cat
(
single_templates
,
dim
=-
4
)
return
z
...
...
@@ -296,6 +346,7 @@ class TemplatePairStack(nn.Module):
"""
Implements Algorithm 16.
"""
def
__init__
(
self
,
c_t
,
...
...
@@ -305,6 +356,8 @@ class TemplatePairStack(nn.Module):
no_heads
,
pair_transition_n
,
dropout_rate
,
tri_mul_first
,
fuse_projection_weights
,
blocks_per_ckpt
,
tune_chunk_size
:
bool
=
False
,
inf
=
1e9
,
...
...
@@ -341,6 +394,8 @@ class TemplatePairStack(nn.Module):
no_heads
=
no_heads
,
pair_transition_n
=
pair_transition_n
,
dropout_rate
=
dropout_rate
,
tri_mul_first
=
tri_mul_first
,
fuse_projection_weights
=
fuse_projection_weights
,
inf
=
inf
,
)
self
.
blocks
.
append
(
block
)
...
...
@@ -349,7 +404,7 @@ class TemplatePairStack(nn.Module):
self
.
tune_chunk_size
=
tune_chunk_size
self
.
chunk_size_tuner
=
None
if
(
tune_chunk_size
)
:
if
tune_chunk_size
:
self
.
chunk_size_tuner
=
ChunkSizeTuner
()
def
forward
(
...
...
@@ -371,7 +426,7 @@ class TemplatePairStack(nn.Module):
Returns:
[*, N_templ, N_res, N_res, C_t] template embedding update
"""
if
(
mask
.
shape
[
-
3
]
==
1
)
:
if
mask
.
shape
[
-
3
]
==
1
:
expand_idx
=
list
(
mask
.
shape
)
expand_idx
[
-
3
]
=
t
.
shape
[
-
4
]
mask
=
mask
.
expand
(
*
expand_idx
)
...
...
@@ -389,8 +444,8 @@ class TemplatePairStack(nn.Module):
for
b
in
self
.
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
)
:
assert
(
not
self
.
training
)
if
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
:
assert
(
not
self
.
training
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
args
=
(
t
.
clone
(),),
...
...
@@ -478,7 +533,7 @@ def embed_templates_offload(
_mask_trans
=
model
.
config
.
_mask_trans
,
)
assert
(
sys
.
getrefcount
(
t
)
==
2
)
assert
(
sys
.
getrefcount
(
t
)
==
2
)
pair_embeds_cpu
.
append
(
t
.
cpu
())
...
...
@@ -504,7 +559,7 @@ def embed_templates_offload(
del
pair_chunks
if
(
inplace_safe
)
:
if
inplace_safe
:
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
else
:
t
*=
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
...
...
@@ -516,9 +571,9 @@ def embed_templates_offload(
)
# [*, N, C_m]
a
=
model
.
template_
a
ngle_embedder
(
template_angle_feat
)
a
=
model
.
template_
si
ngle_embedder
(
template_angle_feat
)
ret
[
"template_
a
ngle_embedding"
]
=
a
ret
[
"template_
si
ngle_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
t
})
...
...
@@ -605,19 +660,19 @@ def embed_templates_average(
)
denom
=
math
.
ceil
(
n_templ
/
templ_group_size
)
if
(
inplace_safe
)
:
if
inplace_safe
:
t
/=
denom
else
:
t
=
t
/
denom
if
(
inplace_safe
)
:
if
inplace_safe
:
out_tensor
+=
t
else
:
out_tensor
=
out_tensor
+
t
del
t
if
(
inplace_safe
)
:
if
inplace_safe
:
out_tensor
*=
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
else
:
out_tensor
=
out_tensor
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
...
...
@@ -629,9 +684,9 @@ def embed_templates_average(
)
# [*, N, C_m]
a
=
model
.
template_
a
ngle_embedder
(
template_angle_feat
)
a
=
model
.
template_
si
ngle_embedder
(
template_angle_feat
)
ret
[
"template_
a
ngle_embedding"
]
=
a
ret
[
"template_
si
ngle_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
out_tensor
})
...
...
openfold/model/triangular_multiplicative_update.py
View file @
bb3f51e5
...
...
@@ -15,6 +15,7 @@
from
functools
import
partialmethod
from
typing
import
Optional
from
abc
import
ABC
,
abstractmethod
import
torch
import
torch.nn
as
nn
...
...
@@ -25,11 +26,12 @@ from openfold.utils.precision_utils import is_fp16_enabled
from
openfold.utils.tensor_utils
import
add
,
permute_final_dims
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
class
Base
TriangleMultiplicativeUpdate
(
nn
.
Module
,
ABC
):
"""
Implements Algorithms 11 and 12.
"""
def
__init__
(
self
,
c_z
,
c_hidden
,
_outgoing
=
True
):
@
abstractmethod
def
__init__
(
self
,
c_z
,
c_hidden
,
_outgoing
):
"""
Args:
c_z:
...
...
@@ -37,15 +39,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
c:
Hidden channel dimension
"""
super
(
TriangleMultiplicativeUpdate
,
self
).
__init__
()
super
(
Base
TriangleMultiplicativeUpdate
,
self
).
__init__
()
self
.
c_z
=
c_z
self
.
c_hidden
=
c_hidden
self
.
_outgoing
=
_outgoing
self
.
linear_a_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
self
.
linear_a_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_b_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
self
.
linear_b_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_g
=
Linear
(
self
.
c_z
,
self
.
c_z
,
init
=
"gating"
)
self
.
linear_z
=
Linear
(
self
.
c_hidden
,
self
.
c_z
,
init
=
"final"
)
...
...
@@ -84,6 +82,46 @@ class TriangleMultiplicativeUpdate(nn.Module):
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
@abstractmethod
def forward(self,
    z: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    inplace_safe: bool = False,
    _add_with_inplace: bool = False
) -> torch.Tensor:
    """
    Abstract interface for the triangular multiplicative update; concrete
    subclasses provide the implementation.

    Args:
        z:
            [*, N_res, N_res, C_z] input tensor
        mask:
            [*, N_res, N_res] input mask
        inplace_safe:
            Flag interpreted by the subclass implementation
        _add_with_inplace:
            Flag interpreted by the subclass implementation
    Returns:
        [*, N_res, N_res, C_z] output tensor
    """
    pass
class
TriangleMultiplicativeUpdate
(
BaseTriangleMultiplicativeUpdate
):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
    """
    Args:
        c_z:
            Input channel dimension
        c_hidden:
            Hidden channel dimension
        _outgoing:
            Forwarded unchanged to the base-class initializer
    """
    super(TriangleMultiplicativeUpdate, self).__init__(
        c_z=c_z,
        c_hidden=c_hidden,
        _outgoing=_outgoing,
    )

    # Separate projection / gate layers for the "a" and "b" branches,
    # registered in the same order as before.
    gate_kwargs = {"init": "gating"}
    layer_specs = (
        ("linear_a_p", {}),
        ("linear_a_g", gate_kwargs),
        ("linear_b_p", {}),
        ("linear_b_g", gate_kwargs),
    )
    for layer_name, layer_kwargs in layer_specs:
        setattr(
            self,
            layer_name,
            Linear(self.c_z, self.c_hidden, **layer_kwargs),
        )
def
_inference_forward
(
self
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -397,7 +435,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
# reduced-precision modes
a_std
=
a
.
std
()
b_std
=
b
.
std
()
if
(
a_std
!=
0.
and
b_std
!=
0.
):
if
(
is_fp16_enabled
()
and
a_std
!=
0.
and
b_std
!=
0.
):
a
=
a
/
a
.
std
()
b
=
b
/
b
.
std
()
...
...
@@ -428,3 +466,152 @@ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
Implements Algorithm 12.
"""
__init__
=
partialmethod
(
TriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
False
)
class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
    """
    Implements Algorithms 11 and 12.

    Variant in which the "a" and "b" projections (and their gates) are each
    produced by a single linear layer of width 2 * c_hidden and split
    afterwards.
    """

    def __init__(self, c_z, c_hidden, _outgoing=True):
        """
        Args:
            c_z:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
            _outgoing:
                Forwarded unchanged to the base-class initializer
        """
        super(FusedTriangleMultiplicativeUpdate, self).__init__(
            c_z=c_z,
            c_hidden=c_hidden,
            _outgoing=_outgoing,
        )

        self.linear_ab_p = Linear(self.c_z, self.c_hidden * 2)
        self.linear_ab_g = Linear(self.c_z, self.c_hidden * 2, init="gating")

    def _inference_forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        _inplace_chunk_size: Optional[int] = None,
        with_add: bool = True,
    ):
        """
        Args:
            z:
                A [*, N, N, C_z] pair representation
            mask:
                A [*, N, N] pair mask. Defaults to all ones.
            _inplace_chunk_size:
                Passed through to self._combine_projections
            with_add:
                If True, the update is added to z in place and z is
                returned. Otherwise the update tensor itself is returned
                and z is not modified.
        Returns:
            z (updated in place) if with_add, else the update tensor
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        def compute_projection_helper(pair, mask):
            # Gate, project and mask using in-place ops on the gate output.
            p = self.linear_ab_g(pair)
            p.sigmoid_()
            p *= self.linear_ab_p(pair)
            p *= mask

            return p

        def compute_projection(pair, mask):
            # The fused output holds "a" in the first c_hidden channels and
            # "b" in the remaining ones.
            p = compute_projection_helper(pair, mask)
            left = p[..., :self.c_hidden]
            right = p[..., self.c_hidden:]

            return left, right

        z_norm_in = self.layer_norm_in(z)
        a, b = compute_projection(z_norm_in, mask)
        x = self._combine_projections(
            a, b, _inplace_chunk_size=_inplace_chunk_size
        )
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        g = self.linear_g(z_norm_in)
        g.sigmoid_()
        x *= g
        if with_add:
            z += x
        else:
            z = x

        return z

    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_safe: bool = False,
        _add_with_inplace: bool = False,
        _inplace_chunk_size: Optional[int] = 256
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask. Defaults to all ones.
            inplace_safe:
                If True, dispatch to the in-place _inference_forward path
            _add_with_inplace:
                Forwarded as `with_add` to _inference_forward; only used
                when inplace_safe is True
            _inplace_chunk_size:
                Forwarded to _inference_forward; only used when
                inplace_safe is True
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        if inplace_safe:
            x = self._inference_forward(
                z,
                mask,
                _inplace_chunk_size=_inplace_chunk_size,
                with_add=_add_with_inplace,
            )
            return x

        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
        ab = mask
        ab = ab * self.sigmoid(self.linear_ab_g(z))
        ab = ab * self.linear_ab_p(z)

        a = ab[..., :self.c_hidden]
        b = ab[..., self.c_hidden:]

        # Prevents overflow of torch.matmul in combine projections in
        # reduced-precision modes. The standard deviations are only needed
        # (and therefore only computed) when fp16 is enabled, and each one
        # is computed once and reused for the rescaling.
        if is_fp16_enabled():
            a_std = a.std()
            b_std = b.std()
            if a_std != 0. and b_std != 0.:
                a = a / a_std
                b = b / b_std

        if is_fp16_enabled():
            with torch.cuda.amp.autocast(enabled=False):
                x = self._combine_projections(a.float(), b.float())
        else:
            x = self._combine_projections(a, b)

        del a, b
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        g = self.sigmoid(self.linear_g(z))
        x = x * g

        return x
class FusedTriangleMultiplicationOutgoing(FusedTriangleMultiplicativeUpdate):
    """
    Implements Algorithm 11.

    Same as FusedTriangleMultiplicativeUpdate, with the constructor's
    `_outgoing` keyword pre-bound to True.
    """
    # partialmethod keeps the parent's signature and only fixes the default
    # direction flag.
    __init__ = partialmethod(FusedTriangleMultiplicativeUpdate.__init__, _outgoing=True)
class FusedTriangleMultiplicationIncoming(FusedTriangleMultiplicativeUpdate):
    """
    Implements Algorithm 12.

    Same as FusedTriangleMultiplicativeUpdate, with the constructor's
    `_outgoing` keyword pre-bound to False.
    """
    # partialmethod keeps the parent's signature and only fixes the default
    # direction flag.
    __init__ = partialmethod(FusedTriangleMultiplicativeUpdate.__init__, _outgoing=False)
openfold/np/protein.py
View file @
bb3f51e5
...
...
@@ -36,6 +36,11 @@ FeatureDict = Mapping[str, np.ndarray]
ModelOutput
=
Mapping
[
str
,
Any
]
# Is a nested dict.
PICO_TO_ANGSTROM
=
0.01
PDB_CHAIN_IDS
=
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
PDB_MAX_CHAINS
=
len
(
PDB_CHAIN_IDS
)
assert
(
PDB_MAX_CHAINS
==
62
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
Protein
:
"""Protein structure representation."""
...
...
@@ -73,6 +78,13 @@ class Protein:
# Chain corresponding to each parent
parents_chain_index
:
Optional
[
Sequence
[
int
]]
=
None
def
__post_init__
(
self
):
if
(
len
(
np
.
unique
(
self
.
chain_index
))
>
PDB_MAX_CHAINS
):
raise
ValueError
(
f
"Cannot build an instance with more than
{
PDB_MAX_CHAINS
}
"
"chains because these cannot be written to PDB format"
)
def
from_pdb_string
(
pdb_str
:
str
,
chain_id
:
Optional
[
str
]
=
None
)
->
Protein
:
"""Takes a PDB string and constructs a Protein object.
...
...
@@ -108,6 +120,7 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
for
chain
in
model
:
if
(
chain_id
is
not
None
and
chain
.
id
!=
chain_id
):
continue
for
res
in
chain
:
if
res
.
id
[
2
]
!=
" "
:
raise
ValueError
(
...
...
@@ -132,6 +145,7 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
if
np
.
sum
(
mask
)
<
0.5
:
# If no known atom positions are reported for the residue then skip it.
continue
aatype
.
append
(
restype_idx
)
atom_positions
.
append
(
pos
)
atom_mask
.
append
(
mask
)
...
...
@@ -224,6 +238,14 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
)
def
_chain_end
(
atom_index
,
end_resname
,
chain_name
,
residue_index
)
->
str
:
chain_end
=
'TER'
return
(
f
'
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
{
end_resname
:
>
3
}
'
f
'
{
chain_name
:
>
1
}{
residue_index
:
>
4
}
'
)
def
get_pdb_headers
(
prot
:
Protein
,
chain_id
:
int
=
0
)
->
Sequence
[
str
]:
pdb_headers
=
[]
...
...
@@ -316,21 +338,46 @@ def to_pdb(prot: Protein) -> str:
atom_positions
=
prot
.
atom_positions
residue_index
=
prot
.
residue_index
.
astype
(
np
.
int32
)
b_factors
=
prot
.
b_factors
chain_index
=
prot
.
chain_index
chain_index
=
prot
.
chain_index
.
astype
(
np
.
int32
)
if
np
.
any
(
aatype
>
residue_constants
.
restype_num
):
raise
ValueError
(
"Invalid aatypes."
)
# Construct a mapping from chain integer indices to chain ID strings.
chain_ids
=
{}
for
i
in
np
.
unique
(
chain_index
):
# np.unique gives sorted output.
if
i
>=
PDB_MAX_CHAINS
:
raise
ValueError
(
f
"The PDB format supports at most
{
PDB_MAX_CHAINS
}
chains."
)
chain_ids
[
i
]
=
PDB_CHAIN_IDS
[
i
]
headers
=
get_pdb_headers
(
prot
)
if
(
len
(
headers
)
>
0
):
if
(
len
(
headers
)
>
0
):
pdb_lines
.
extend
(
headers
)
pdb_lines
.
append
(
"MODEL 1"
)
n
=
aatype
.
shape
[
0
]
atom_index
=
1
last_chain_index
=
chain_index
[
0
]
prev_chain_index
=
0
chain_tags
=
string
.
ascii_uppercase
# Add all atom sites.
for
i
in
range
(
n
):
for
i
in
range
(
aatype
.
shape
[
0
]):
# Close the previous chain if in a multichain PDB.
if
last_chain_index
!=
chain_index
[
i
]:
pdb_lines
.
append
(
_chain_end
(
atom_index
,
res_1to3
(
aatype
[
i
-
1
]),
chain_ids
[
chain_index
[
i
-
1
]],
residue_index
[
i
-
1
]
)
)
last_chain_index
=
chain_index
[
i
]
atom_index
+=
1
# Atom index increases at the TER symbol.
res_name_3
=
res_1to3
(
aatype
[
i
])
for
atom_name
,
pos
,
mask
,
b_factor
in
zip
(
atom_types
,
atom_positions
[
i
],
atom_mask
[
i
],
b_factors
[
i
]
...
...
@@ -355,6 +402,8 @@ def to_pdb(prot: Protein) -> str:
# PDB is a columnar format, every space matters here!
atom_line
=
(
f
"
{
record_type
:
<
6
}{
atom_index
:
>
5
}
{
name
:
<
4
}{
alt_loc
:
>
1
}
"
#TODO: check this refactor, chose main branch version
#f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
f
"
{
res_name_3
:
>
3
}
{
chain_tag
:
>
1
}
"
f
"
{
residue_index
[
i
]:
>
4
}{
insertion_code
:
>
1
}
"
f
"
{
pos
[
0
]:
>
8.3
f
}{
pos
[
1
]:
>
8.3
f
}{
pos
[
2
]:
>
8.3
f
}
"
...
...
@@ -386,9 +435,12 @@ def to_pdb(prot: Protein) -> str:
# each new chain.
pdb_lines
.
extend
(
get_pdb_headers
(
prot
,
prev_chain_index
))
pdb_lines
.
append
(
"ENDMDL"
)
pdb_lines
.
append
(
"END"
)
pdb_lines
.
append
(
""
)
return
"
\n
"
.
join
(
pdb_lines
)
# Pad all lines to 80 characters
pdb_lines
=
[
line
.
ljust
(
80
)
for
line
in
pdb_lines
]
return
'
\n
'
.
join
(
pdb_lines
)
+
'
\n
'
# Add terminating newline.
def
to_modelcif
(
prot
:
Protein
)
->
str
:
...
...
@@ -539,7 +591,7 @@ def from_prediction(
features
:
FeatureDict
,
result
:
ModelOutput
,
b_factors
:
Optional
[
np
.
ndarray
]
=
None
,
chain_index
:
Optional
[
np
.
ndarray
]
=
Non
e
,
remove_leading_feature_dimension
:
bool
=
Tru
e
,
remark
:
Optional
[
str
]
=
None
,
parents
:
Optional
[
Sequence
[
str
]]
=
None
,
parents_chain_index
:
Optional
[
Sequence
[
int
]]
=
None
...
...
@@ -550,20 +602,32 @@ def from_prediction(
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values
chain_index: (Optional) Chain indices for multi-chain predictions
remark: (Optional) Remark about the prediction
parents: (Optional) List of template names
Returns:
A protein instance.
"""
def
_maybe_remove_leading_dim
(
arr
:
np
.
ndarray
)
->
np
.
ndarray
:
return
arr
[
0
]
if
remove_leading_feature_dimension
else
arr
if
'asym_id'
in
features
:
chain_index
=
_maybe_remove_leading_dim
(
features
[
"asym_id"
])
-
1
else
:
chain_index
=
np
.
zeros_like
(
_maybe_remove_leading_dim
(
features
[
"aatype"
])
)
if
b_factors
is
None
:
b_factors
=
np
.
zeros_like
(
result
[
"final_atom_mask"
])
return
Protein
(
aatype
=
features
[
"aatype"
],
aatype
=
_maybe_remove_leading_dim
(
features
[
"aatype"
]
)
,
atom_positions
=
result
[
"final_atom_positions"
],
atom_mask
=
result
[
"final_atom_mask"
],
residue_index
=
features
[
"residue_index"
]
+
1
,
residue_index
=
_maybe_remove_leading_dim
(
features
[
"residue_index"
]
)
+
1
,
b_factors
=
b_factors
,
chain_index
=
chain_index
,
remark
=
remark
,
...
...
openfold/np/relax/amber_minimize.py
View file @
bb3f51e5
...
...
@@ -563,60 +563,3 @@ def run_pipeline(
)
iteration
+=
1
return
ret
def get_initial_energies(
    pdb_strs: Sequence[str],
    stiffness: float = 0.0,
    restraint_set: str = "non_hydrogen",
    exclude_residues: Optional[Sequence[int]] = None,
):
    """Returns initial potential energies for a sequence of PDBs.

    Assumes the input PDBs are ready for minimization, and all have the same
    topology.
    Allows time to be saved by not pdbfixing / rebuilding the system.

    Args:
        pdb_strs: List of PDB strings.
        stiffness: kcal/mol A**2, spring constant of heavy atom restraining
            potential.
        restraint_set: Which atom types to restrain.
        exclude_residues: An optional list of zero-indexed residues to exclude from
            restraints.

    Returns:
        A list of initial energies (plain floats, expressed in ENERGY units)
        in the same order as pdb_strs. A structure whose energy cannot be
        evaluated is reported as 1e20. An empty input yields an empty list.
    """
    exclude_residues = exclude_residues or []

    openmm_pdbs = [
        openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs
    ]
    if not openmm_pdbs:
        # Nothing to evaluate; the system below is built from the first PDB.
        return []

    # The system and simulation are built once from the first structure's
    # topology and reused for every structure.
    force_field = openmm_app.ForceField("amber99sb.xml")
    system = force_field.createSystem(
        openmm_pdbs[0].topology, constraints=openmm_app.HBonds
    )
    stiffness = stiffness * ENERGY / (LENGTH ** 2)
    if stiffness > 0 * ENERGY / (LENGTH ** 2):
        _add_restraints(
            system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues
        )
    simulation = openmm_app.Simulation(
        openmm_pdbs[0].topology,
        system,
        openmm.LangevinIntegrator(0, 0.01, 0.0),
        openmm.Platform.getPlatformByName("CPU"),
    )
    energies = []
    for pdb in openmm_pdbs:
        try:
            simulation.context.setPositions(pdb.positions)
            state = simulation.context.getState(getEnergy=True)
            energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
        except Exception as e:  # pylint: disable=broad-except
            logging.error(
                "Error getting initial energy, returning large value %s", e
            )
            # Keep the element type consistent with the success path
            # (a float in ENERGY units) instead of a unit-carrying Quantity.
            energies.append(1e20)
    return energies
openfold/np/relax/utils.py
View file @
bb3f51e5
...
...
@@ -79,7 +79,7 @@ def assert_equal_nonterminal_atom_types(
"""Checks that pre- and post-minimized proteins have same atom set."""
# Ignore any terminal OXT atoms which may have been added by minimization.
oxt
=
residue_constants
.
atom_order
[
"OXT"
]
no_oxt_mask
=
np
.
ones
(
shape
=
atom_mask
.
shape
,
dtype
=
np
.
bool
)
no_oxt_mask
=
np
.
ones
(
shape
=
atom_mask
.
shape
,
dtype
=
bool
)
no_oxt_mask
[...,
oxt
]
=
False
np
.
testing
.
assert_almost_equal
(
ref_atom_mask
[
no_oxt_mask
],
atom_mask
[
no_oxt_mask
]
...
...
openfold/np/residue_constants.py
View file @
bb3f51e5
...
...
@@ -17,14 +17,13 @@
import
collections
import
functools
import
os
from
typing
import
Mapping
,
List
,
Tuple
from
importlib
import
resources
import
numpy
as
np
import
tree
# Internal import (35fd).
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca
=
3.80209737096
...
...
@@ -450,9 +449,9 @@ def load_stereo_chemical_props() -> Tuple[
("residue_virtual_bonds").
Returns:
residue_bonds:
d
ict that maps resname
-
-> list of Bond tuples
residue_virtual_bonds:
d
ict that maps resname
-
-> list of Bond tuples
residue_bond_angles:
d
ict that maps resname
-
-> list of BondAngle tuples
residue_bonds:
D
ict that maps resname -> list of Bond tuples
residue_virtual_bonds:
D
ict that maps resname -> list of Bond tuples
residue_bond_angles:
D
ict that maps resname -> list of BondAngle tuples
"""
# TODO: this file should be downloaded in a setup script
stereo_chemical_props
=
resources
.
read_text
(
"openfold.resources"
,
"stereo_chemical_props.txt"
)
...
...
@@ -1310,3 +1309,179 @@ def aatype_to_str_sequence(aatype):
restypes_with_x
[
aatype
[
i
]]
for
i
in
range
(
len
(
aatype
))
])
### ALPHAFOLD MULTIMER STUFF ###
def _make_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
      A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
      in the order specified in residue_constants.restypes + unknown residue type
      at the end. For chi angles which are not defined on the residue, the
      positions indices are by default set to 0.
    """
    table = []
    for one_letter in restypes:
        three_letter = restype_1to3[one_letter]
        per_chi = [
            [atom_order[atom] for atom in chi_atoms]
            for chi_atoms in chi_angles_atoms[three_letter]
        ]
        # Pad with zero rows for chi angles not defined on the AA.
        n_missing = 4 - len(per_chi)
        per_chi.extend([0, 0, 0, 0] for _ in range(n_missing))
        table.append(per_chi)

    # Trailing all-zero entry for the UNKNOWN residue.
    table.append([[0, 0, 0, 0]] * 4)

    return np.array(table)
def _make_renaming_matrices():
    """Matrices to map atoms to symmetry partners in ambiguous case."""
    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative groundtruth coordinates where the naming is swapped
    names_3 = [restype_1to3[letter] for letter in restypes] + ['UNK']

    # Start from the identity for every residue type; only residues with
    # ambiguous atoms get a non-trivial permutation below.
    matrix_by_name = {
        name: np.eye(14, dtype=np.float32) for name in names_3
    }
    for resname, swap in residue_atom_renaming_swaps.items():
        permutation = np.arange(14)
        for first_atom, second_atom in swap.items():
            first_idx = restype_name_to_atom14_names[resname].index(
                first_atom
            )
            second_idx = restype_name_to_atom14_names[resname].index(
                second_atom
            )
            permutation[first_idx] = second_idx
            permutation[second_idx] = first_idx

        # One-hot encode the permutation row by row.
        matrix = np.zeros((14, 14), dtype=np.float32)
        matrix[np.arange(14), permutation] = 1.
        matrix_by_name[resname] = matrix

    return np.stack([matrix_by_name[name] for name in names_3])
def _make_restype_atom37_mask():
    """Mask of which atoms are present for which residue type in atom37."""
    # Rows follow `restypes`; the final (21st) row is left all-zero.
    mask = np.zeros([21, 37], dtype=np.float32)
    for row, letter in enumerate(restypes):
        present = [
            atom_order[atom_name]
            for atom_name in residue_atoms[restype_1to3[letter]]
        ]
        mask[row, present] = 1
    return mask
def _make_restype_atom14_mask():
    """Mask of which atoms are present for which residue type in atom14."""
    # An empty atom name marks an unused atom14 slot.
    rows = [
        [
            1. if name else 0.
            for name in restype_name_to_atom14_names[restype_1to3[letter]]
        ]
        for letter in restypes
    ]
    # Trailing all-zero row after the entries for `restypes`.
    rows.append([0.] * 14)
    return np.array(rows, dtype=np.float32)
def _make_restype_atom37_to_atom14():
    """Map from atom37 to atom14 per residue type."""
    # mapping (restype, atom37) --> atom14
    rows = []
    for letter in restypes:
        names_14 = restype_name_to_atom14_names[restype_1to3[letter]]
        slot_of = {name: slot for slot, name in enumerate(names_14)}
        # Atoms absent from this residue's atom14 layout map to slot 0.
        rows.append([slot_of.get(name, 0) for name in atom_types])

    # Trailing all-zero row after the entries for `restypes`.
    rows.append([0] * 37)
    return np.array(rows, dtype=np.int32)
def _make_restype_atom14_to_atom37():
    """Map from atom14 to atom37 per residue type."""
    # mapping (restype, atom14) --> atom37
    # Unused atom14 slots (empty name) map to atom37 index 0.
    rows = [
        [
            atom_order[name] if name else 0
            for name in restype_name_to_atom14_names[restype_1to3[letter]]
        ]
        for letter in restypes
    ]
    # Add dummy mapping for restype 'UNK'
    rows.append([0] * 14)
    return np.array(rows, dtype=np.int32)
def _make_restype_atom14_is_ambiguous():
    """Mask which atoms are ambiguous in atom14."""
    # create an ambiguous atoms mask.  shape: (21, 14)
    ambiguous = np.zeros((21, 14), dtype=np.float32)
    for resname, swap in residue_atom_renaming_swaps.items():
        for atom_pair in swap.items():
            row = restype_order[restype_3to1[resname]]
            names_14 = restype_name_to_atom14_names[resname]
            # Both members of a swappable pair are flagged.
            for atom_name in atom_pair:
                ambiguous[row, names_14.index(atom_name)] = 1
    return ambiguous
def _make_restype_rigidgroup_base_atom37_idx():
    """Create Map from rigidgroups to atom37 indices."""
    # Array of atom names, filled group by group.
    # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
    names = np.full([21, 8, 3], '', dtype=object)

    # 0: backbone frame
    names[:, 0, :] = ['C', 'CA', 'N']

    # 3: 'psi-group'
    names[:, 3, :] = ['CA', 'C', 'O']

    # 4,5,6,7: 'chi1,2,3,4-group'
    for restype, letter in enumerate(restypes):
        resname = restype_1to3[letter]
        for chi_idx in range(4):
            if not chi_angles_mask[restype][chi_idx]:
                continue
            # The first chi atom is dropped; the rest are the base atoms.
            chi_atoms = chi_angles_atoms[resname][chi_idx]
            names[restype, chi_idx + 4, :] = chi_atoms[1:]

    # Translate atom names into atom37 indices; the empty name (unused
    # group) maps to index 0.
    name_to_idx = {**atom_order, '': 0}
    return np.vectorize(lambda name: name_to_idx[name])(names)
# Lookup tables for the multimer code path, built once at import time.
CHI_ATOM_INDICES = _make_chi_atom_indices()
RENAMING_MATRICES = _make_renaming_matrices()
RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()
RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()
RESTYPE_ATOM37_MASK = _make_restype_atom37_mask()
RESTYPE_ATOM14_MASK = _make_restype_atom14_mask()
RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous()
RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx()

# Create mask for existing rigid groups.
RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32)
# Groups 0 and 3 are set for every residue type.
RESTYPE_RIGIDGROUP_MASK[:, [0, 3]] = 1
# Groups 4-7 follow the chi-angle mask for the first 20 residue types.
RESTYPE_RIGIDGROUP_MASK[:20, 4:] = chi_angles_mask
openfold/utils/all_atom_multimer.py
0 → 100644
View file @
bb3f51e5
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ops for all atom representations."""
from
functools
import
partial
from
typing
import
Dict
,
Text
,
Tuple
import
torch
from
openfold.np
import
residue_constants
as
rc
from
openfold.utils
import
geometry
,
tensor_utils
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
import
numpy
as
np
def squared_difference(x, y):
    """Return the element-wise square of ``x - y`` (computed with numpy)."""
    delta = x - y
    return np.square(delta)
def get_rc_tensor(rc_np, aatype):
    """Look up rows of a residue-constant table by residue type.

    The table `rc_np` is converted to a tensor on `aatype`'s device and
    indexed along its first axis with `aatype`.
    """
    table = torch.tensor(rc_np, device=aatype.device)
    return table[aatype]
def atom14_to_atom37(
    atom14_data: torch.Tensor,  # (*, N, 14, ...)
    aatype: torch.Tensor  # (*, N)
) -> Tuple:  # (*, N, 37, ...)
    """Convert atom14 to atom37 representation.

    Args:
        atom14_data: Per-atom data in atom14 layout, either (*, N, 14) or
            (*, N, 14, C).
        aatype: Residue type indices, (*, N).
    Returns:
        A tuple (atom37_data, atom37_mask): the data gathered into atom37
        layout with non-existent atoms zeroed, and the atom37 existence
        mask looked up from rc.RESTYPE_ATOM37_MASK.
    Raises:
        ValueError: if atom14_data has neither one nor two dimensions more
            than aatype.
    """
    idx_atom37_to_atom14 = get_rc_tensor(
        rc.RESTYPE_ATOM37_TO_ATOM14, aatype
    ).long()
    no_batch_dims = len(aatype.shape) - 1
    atom37_data = tensor_utils.batched_gather(
        atom14_data,
        idx_atom37_to_atom14,
        dim=no_batch_dims + 1,
        no_batch_dims=no_batch_dims + 1
    )
    atom37_mask = get_rc_tensor(rc.RESTYPE_ATOM37_MASK, aatype)

    # Cast the mask to the data dtype in both branches so the in-place
    # multiply also works for non-float data; the returned mask is left
    # as looked up.
    data_mask = atom37_mask.to(dtype=atom37_data.dtype)
    n_data_dims = len(atom14_data.shape)
    if n_data_dims == no_batch_dims + 2:
        atom37_data *= data_mask
    elif n_data_dims == no_batch_dims + 3:
        atom37_data *= data_mask[..., None]
    else:
        raise ValueError("Incorrectly shaped data")

    return atom37_data, atom37_mask
def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
    """Convert Atom37 positions to Atom14 positions.

    Args:
        aatype: Residue type indices, (*, N).
        all_atom_pos: Atom positions in atom37 layout.
            NOTE(review): handled as (*, N, 37) plus optional trailing
            coordinate dimensions -- confirm against callers.
        all_atom_mask: Atom mask in atom37 layout, (*, N, 37).
    Returns:
        A tuple (atom14_positions, atom14_mask); positions of atoms that are
        masked out are zeroed.
    """
    residx_atom14_to_atom37 = get_rc_tensor(
        rc.RESTYPE_ATOM14_TO_ATOM37, aatype
    ).long()

    # Same convention as atom14_to_atom37: the gather treats (*, N) as batch
    # dimensions and indexes the atom dimension that follows them.
    no_batch_dims = len(aatype.shape) - 1
    gather_dim = no_batch_dims + 1

    atom14_mask = tensor_utils.batched_gather(
        all_atom_mask,
        residx_atom14_to_atom37,
        dim=gather_dim,
        no_batch_dims=gather_dim,
    ).to(all_atom_pos.dtype)

    # create a mask for known groundtruth positions
    atom14_mask *= get_rc_tensor(rc.RESTYPE_ATOM14_MASK, aatype)

    # gather the groundtruth positions
    atom14_positions = tensor_utils.batched_gather(
        all_atom_pos,
        residx_atom14_to_atom37,
        dim=gather_dim,
        no_batch_dims=gather_dim,
    )

    # Broadcast the mask over any trailing (coordinate) dimensions.
    n_extra_dims = len(atom14_positions.shape) - len(atom14_mask.shape)
    position_mask = atom14_mask
    if n_extra_dims > 0:
        position_mask = atom14_mask[(...,) + (None,) * n_extra_dims]

    atom14_positions = position_mask * atom14_positions
    return atom14_positions, atom14_mask
def get_alt_atom14(aatype, positions: torch.Tensor, mask):
    """Get alternative atom14 positions.

    Args:
        aatype: Residue type indices, (*, N).
        positions: atom14 positions.
            NOTE(review): treated as a tensor of shape (*, N, 14, C), as the
            annotation suggests -- confirm against callers.
        mask: atom14 mask, (*, N, 14).
    Returns:
        A tuple (alternative_positions, alternative_mask) in which the
        atoms of each ambiguous pair are swapped.
    """
    # pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14)
    renaming_transform = get_rc_tensor(rc.RENAMING_MATRICES, aatype)

    # alt[..., j, :] = sum_i positions[..., i, :] * transform[..., i, j].
    # The reduction has to run over the source-atom axis (dim=-3 of the
    # (*, N, 14, 14, C) product), mirroring the mask computation below.
    # Reducing over the target axis would sum each permutation row to 1 and
    # hand the input back unchanged.
    alternative_positions = torch.sum(
        positions[..., :, None, :] * renaming_transform[..., None],
        dim=-3
    )

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position)
    alternative_mask = torch.sum(mask[..., None] * renaming_transform, dim=-2)

    return alternative_positions, alternative_mask
def atom37_to_frames(
    aatype: torch.Tensor,  # (...)
    all_atom_positions: torch.Tensor,  # (..., 37)
    all_atom_mask: torch.Tensor,  # (..., 37)
) -> Dict[Text, torch.Tensor]:
    """Computes the frames for the up to 8 rigid groups for each residue.

    NOTE(review): the indexing below (`base_atom_pos[..., :, :, 0]`, passing
    the results to Rot3Array.from_two_vectors) presumes `all_atom_positions`
    is a vector-array object of shape (..., 37) exposing `.dtype`, despite
    the torch.Tensor annotation -- confirm against callers.

    Returns:
        A dict with the ground-truth frames, their existence masks, the
        ambiguity mask and the alternative ground-truth frames, each with a
        trailing dimension of 8 rigid groups.
    """
    # 0: 'backbone group',
    # 1: 'pre-omega-group', (empty)
    # 2: 'phi-group', (currently empty, because it defines only hydrogens)
    # 3: 'psi-group',
    # 4,5,6,7: 'chi1,2,3,4-group'

    no_batch_dims = len(aatype.shape) - 1
    gather_dim = no_batch_dims + 1
    dtype = all_atom_positions.dtype
    device = aatype.device

    # Compute the gather indices for all residues in the chain.
    # shape (N, 8, 3)
    residx_rigidgroup_base_atom37_idx = get_rc_tensor(
        rc.RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype
    )

    # Gather the base atom positions for each rigid group.
    # (keyword is `no_batch_dims`, as in the other batched_gather calls in
    # this module)
    base_atom_pos = tensor_utils.batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=gather_dim,
        no_batch_dims=gather_dim,
    )

    # Compute the Rigids.
    point_on_neg_x_axis = base_atom_pos[..., :, :, 0]
    origin = base_atom_pos[..., :, :, 1]
    point_on_xy_plane = base_atom_pos[..., :, :, 2]
    gt_rotation = geometry.Rot3Array.from_two_vectors(
        origin - point_on_neg_x_axis, point_on_xy_plane - origin
    )

    gt_frames = geometry.Rigid3Array(gt_rotation, origin)

    # Compute a mask whether the group exists.
    # (N, 8)
    group_exists = get_rc_tensor(rc.RESTYPE_RIGIDGROUP_MASK, aatype)

    # Compute a mask whether ground truth exists for the group
    gt_atoms_exist = tensor_utils.batched_gather(  # shape (N, 8, 3)
        all_atom_mask.to(dtype=dtype),
        residx_rigidgroup_base_atom37_idx,
        dim=gather_dim,
        no_batch_dims=gather_dim,
    )
    # torch.min with `dim` returns (values, indices); only the values are
    # part of the mask.
    gt_exists = torch.min(gt_atoms_exist, dim=-1).values * group_exists  # (N, 8)

    # Adapt backbone frame to old convention (mirror x-axis and z-axis).
    # Built directly in torch: numpy cannot interpret a torch dtype.
    rots = torch.eye(3, dtype=dtype, device=device).repeat(8, 1, 1)
    rots[0, 0, 0] = -1
    rots[0, 2, 2] = -1
    gt_frames = gt_frames.compose_rotation(
        geometry.Rot3Array.from_array(rots)
    )

    # The frames for ambiguous rigid groups are just rotated by 180 degree around
    # the x-axis. The ambiguous group is always the last chi-group.
    restype_rigidgroup_is_ambiguous = torch.zeros(
        (21, 8), dtype=dtype, device=device
    )
    restype_rigidgroup_rots = torch.eye(
        3, dtype=dtype, device=device
    ).repeat(21, 8, 1, 1)

    for resname in rc.residue_atom_renaming_swaps:
        restype = rc.restype_order[rc.restype_3to1[resname]]
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1

    # Gather the ambiguity information for each residue.
    residx_rigidgroup_is_ambiguous = restype_rigidgroup_is_ambiguous[aatype]
    ambiguity_rot = geometry.Rot3Array.from_array(
        restype_rigidgroup_rots[aatype]
    )

    # Create the alternative ground truth frames.
    alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)

    # reshape back to original residue layout: one entry per residue in
    # `aatype`, times 8 rigid groups
    target_shape = tuple(aatype.shape) + (8,)
    fix_shape = lambda x: x.reshape(target_shape)

    gt_frames = fix_shape(gt_frames)
    gt_exists = fix_shape(gt_exists)
    group_exists = fix_shape(group_exists)
    residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous)
    alt_gt_frames = fix_shape(alt_gt_frames)

    return {
        'rigidgroups_gt_frames': gt_frames,  # Rigid (..., 8)
        'rigidgroups_gt_exists': gt_exists,  # (..., 8)
        'rigidgroups_group_exists': group_exists,  # (..., 8)
        'rigidgroups_group_is_ambiguous':
            residx_rigidgroup_is_ambiguous,  # (..., 8)
        'rigidgroups_alt_gt_frames': alt_gt_frames,  # Rigid (..., 8)
    }
def torsion_angles_to_frames(
    aatype: torch.Tensor,  # (N)
    backb_to_global: geometry.Rigid3Array,  # (N)
    torsion_angles_sin_cos: torch.Tensor  # (N, 7, 2)
) -> geometry.Rigid3Array:  # (N, 8)
    """Compute rigid group frames from torsion angles.

    Args:
        aatype: Residue type indices.
        backb_to_global: Backbone-to-global frame of every residue.
        torsion_angles_sin_cos: (sin, cos) pairs of the 7 torsion angles of
            every residue.
    Returns:
        The 8 rigid-group frames of every residue, expressed in the global
        frame.
    """
    # Gather the default frames for all rigid groups.
    # geometry.Rigid3Array with shape (N, 8)
    m = get_rc_tensor(rc.restype_rigid_group_default_frame, aatype)
    default_frames = geometry.Rigid3Array.from_array4x4(m)

    # Create the rotation matrices according to the given angles (each frame is
    # defined such that its rotation is around the x-axis).
    sin_angles = torsion_angles_sin_cos[..., 0]
    cos_angles = torsion_angles_sin_cos[..., 1]

    # insert zero rotation for backbone group. The padding is created in the
    # dtype of the angles rather than the integer dtype of aatype, so the
    # concatenation does not depend on implicit type promotion.
    backbone_shape = tuple(aatype.shape) + (1,)
    sin_angles = torch.cat(
        [sin_angles.new_zeros(backbone_shape), sin_angles],
        dim=-1
    )
    cos_angles = torch.cat(
        [cos_angles.new_ones(backbone_shape), cos_angles],
        dim=-1
    )
    zeros = torch.zeros_like(sin_angles)
    ones = torch.ones_like(sin_angles)

    # all_rots are geometry.Rot3Array with shape (..., N, 8)
    all_rots = geometry.Rot3Array(
        ones, zeros, zeros,
        zeros, cos_angles, -sin_angles,
        zeros, sin_angles, cos_angles
    )

    # Apply rotations to the frames.
    all_frames = default_frames.compose_rotation(all_rots)

    # chi2, chi3, and chi4 frames do not transform to the backbone frame but to
    # the previous frame. So chain them up accordingly.
    chi1_frame_to_backb = all_frames[..., 4]
    chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[..., 5]
    chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[..., 6]
    chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[..., 7]

    all_frames_to_backb = Rigid3Array.cat(
        [
            all_frames[..., 0:5],
            chi2_frame_to_backb[..., None],
            chi3_frame_to_backb[..., None],
            chi4_frame_to_backb[..., None]
        ],
        dim=-1
    )

    # Create the global frames.
    # shape (N, 8)
    all_frames_to_global = backb_to_global[..., None] @ all_frames_to_backb

    return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
    aatype: torch.Tensor,  # (*, N)
    all_frames_to_global: geometry.Rigid3Array  # (N, 8)
) -> geometry.Vec3Array:  # (*, N, 14)
    """Place each atom14 literature position using its rigid group's frame.

    Args:
        aatype: Residue type indices used for all residue-constant lookups.
        all_frames_to_global: Transform of every rigid group (8 per residue)
            into the global frame.

    Returns:
        A Vec3Array of atom positions after applying the selected frame to
        the literature positions, multiplied by the atom14 existence mask.
    """
    # One-hot selector over the 8 rigid groups for every atom.
    # shape (*, N, 14, 8)
    atom_group_idx = get_rc_tensor(rc.restype_atom14_to_rigid_group, aatype)
    group_selector = torch.nn.functional.one_hot(atom_group_idx, num_classes=8)

    # Zero out every frame except the selected one, then collapse the group
    # axis by summation, leaving one transform per atom.
    frames_per_atom = all_frames_to_global[..., None, :] * group_selector
    sum_over_groups = partial(torch.sum, dim=-1)
    atom_to_global = frames_per_atom.map_tensor_fn(sum_over_groups)

    # Literature (idealised) atom coordinates expressed in the local frames.
    local_coords = get_rc_tensor(
        rc.restype_atom14_rigid_group_positions, aatype
    )
    lit_positions = geometry.Vec3Array.from_array(local_coords)

    # Move every atom from its local frame to the global frame.
    global_positions = atom_to_global.apply_to_point(lit_positions)

    # Mask out atoms that do not exist for the residue type.
    atom_exists = get_rc_tensor(rc.restype_atom14_mask, aatype)
    return global_positions * atom_exists
def extreme_ca_ca_distance_violations(
    positions: geometry.Vec3Array,  # (N, 37(14))
    mask: torch.Tensor,  # (N, 37(14))
    residue_index: torch.Tensor,  # (N)
    max_angstrom_tolerance=1.5,
    eps: float = 1e-6
) -> torch.Tensor:
    """Counts residues whose Ca is a large distance from its neighbor.

    Args:
        positions: Atom positions; index 1 of the last axis is used as the
            Ca atom.
        mask: Atom mask with the same layout as ``positions``.
        residue_index: Residue numbering; only consecutive pairs whose indices
            differ by exactly 1 are considered.
        max_angstrom_tolerance: Allowed excess of the Ca-Ca distance over
            ``rc.ca_ca`` before a pair counts as a violation.
        eps: Small constant forwarded to ``geometry.euclidean_distance``.

    Returns:
        The masked mean (over the last axis) of the per-pair violation flags.
    """
    this_ca_pos = positions[..., :-1, 1]  # (N - 1,)
    this_ca_mask = mask[..., :-1, 1]  # (N - 1)
    next_ca_pos = positions[..., 1:, 1]  # (N - 1,)
    next_ca_mask = mask[..., 1:, 1]  # (N - 1)

    # torch.Tensor has no ``astype`` (that is the NumPy/JAX spelling); use
    # ``.to`` to convert the boolean comparison into the coordinate dtype.
    has_no_gap_mask = (
        (residue_index[..., 1:] - residue_index[..., :-1]) == 1
    ).to(positions.x.dtype)

    ca_ca_distance = geometry.euclidean_distance(
        this_ca_pos, next_ca_pos, eps
    )
    violations = (ca_ca_distance - rc.ca_ca) > max_angstrom_tolerance

    # A pair only counts when both Ca atoms exist and there is no chain gap.
    pair_mask = this_ca_mask * next_ca_mask * has_no_gap_mask
    return tensor_utils.masked_mean(mask=pair_mask, value=violations, dim=-1)
def get_chi_atom_indices(device: torch.device):
    """Build the lookup table of atom indices defining each chi angle.

    Args:
        device: Device on which the resulting tensor is created.

    Returns:
        A tensor of shape [residue_types=21, chis=4, atoms=4]. Rows follow the
        order of rc.restypes, with one extra row for the unknown residue type
        at the end. Chi angles that a residue does not define are filled with
        index 0.
    """
    table = []
    for one_letter in rc.restypes:
        three_letter = rc.restype_1to3[one_letter]
        per_chi = [
            [rc.atom_order[atom] for atom in chi_angle]
            for chi_angle in rc.chi_angles_atoms[three_letter]
        ]
        # Pad up to 4 entries for chi angles not defined on the AA.
        n_missing = 4 - len(per_chi)
        per_chi += [[0, 0, 0, 0] for _ in range(n_missing)]
        table.append(per_chi)

    # Extra all-zero row for the UNKNOWN residue.
    table.append([[0, 0, 0, 0]] * 4)

    return torch.tensor(table, device=device)
def compute_chi_angles(
    positions: geometry.Vec3Array,
    mask: torch.Tensor,
    aatype: torch.Tensor
):
    """Computes the chi angles given all atom positions and the amino acid type.

    Args:
        positions: A Vec3Array of shape
            [num_res, rc.atom_type_num], with positions of
            atoms needed to calculate chi angles.
        mask: A tensor of shape [num_res, rc.atom_type_num] that masks which
            atom positions are set for each residue, or None. If given, the
            chi mask is 1 for a chi angle only if the amino acid has that chi
            angle and all four atoms needed to calculate it are set. If None,
            the chi mask is 1 whenever the amino acid has that chi angle,
            regardless of whether the atoms were actually set.
        aatype: A tensor of shape [num_res] with amino acid type integer
            codes. Codes above 20 are clamped to 20 (the unknown residue row).

    Returns:
        A tuple of tensors (chi_angles, chi_mask), where both have shape
        [num_res, 4]. The mask masks out unused chi angles for amino acid
        types that have fewer than 4 chi angles. If ``mask`` is given, the
        chi mask also masks out uncomputable chi angles.
    """
    # Don't assert on the num_res and batch dimensions as they might be
    # unknown.
    assert positions.shape[-1] == rc.atom_type_num
    if mask is not None:
        assert mask.shape[-1] == rc.atom_type_num
    no_batch_dims = len(aatype.shape) - 1

    # Compute the table of chi angle indices.
    # Shape: [restypes, chis=4, atoms=4].
    chi_atom_indices = get_chi_atom_indices(aatype.device)

    # DISCREPANCY: DeepMind doesn't remove the gaps here. I don't know why
    # theirs works.
    aatype_gapless = torch.clamp(aatype, max=20)

    # Select atoms to compute chis. Shape: [*, num_res, chis=4, atoms=4].
    atom_indices = chi_atom_indices[aatype_gapless]

    # Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3].
    chi_angle_atoms = positions.map_tensor_fn(
        partial(
            tensor_utils.batched_gather,
            inds=atom_indices,
            dim=-1,
            no_batch_dims=no_batch_dims + 1
        )
    )

    a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)]
    chi_angles = geometry.dihedral_angle(a, b, c, d)

    # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
    chi_angles_mask = list(rc.chi_angles_mask)
    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
    chi_angles_mask = torch.tensor(chi_angles_mask, device=aatype.device)

    # Compute the chi angle mask. Shape [num_res, chis=4].
    chi_mask = chi_angles_mask[aatype_gapless]

    # Without an atom mask there is nothing further to filter on: report
    # every chi angle the residue type defines, as the docstring promises.
    if mask is None:
        return chi_angles, chi_mask

    # The chi_mask is set to 1 only when all necessary chi angle atoms were
    # set. Gather the chi angle atoms mask.
    # Shape: [num_res, chis=4, atoms=4].
    chi_angle_atoms_mask = tensor_utils.batched_gather(
        mask,
        atom_indices,
        dim=-1,
        no_batch_dims=no_batch_dims + 1
    )

    # Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4].
    chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1)
    chi_mask = chi_mask * chi_angle_atoms_mask.to(chi_angles.dtype)

    return chi_angles, chi_mask
def make_transform_from_reference(
    a_xyz: geometry.Vec3Array,
    b_xyz: geometry.Vec3Array,
    c_xyz: geometry.Vec3Array
) -> geometry.Rigid3Array:
    """Build the rigid transform defined by three reference points.

    The rotation is derived from the two vectors b->c and b->a, and the
    translation is the b point itself.

    Symmetries are not handled here. If the coordinates are supplied in the
    non-standard way, the A atom ends up in the negative y-axis rather than
    the positive y-axis; callers must deal with such cases themselves.

    Args:
        a_xyz: A Vec3Array.
        b_xyz: A Vec3Array.
        c_xyz: A Vec3Array.

    Returns:
        A Rigid3Array which, when applied to coordinates in a canonicalized
        reference frame, will give coordinates approximately equal to
        the original coordinates (in the global frame).
    """
    b_to_c = c_xyz - b_xyz
    b_to_a = a_xyz - b_xyz
    frame_rotation = geometry.Rot3Array.from_two_vectors(b_to_c, b_to_a)
    return geometry.Rigid3Array(frame_rotation, b_xyz)
def make_backbone_affine(
    positions: geometry.Vec3Array,
    mask: torch.Tensor,
    aatype: torch.Tensor,
) -> Tuple[geometry.Rigid3Array, torch.Tensor]:
    """Construct per-residue backbone frames from the N, CA and C atoms.

    Args:
        positions: Atom positions, indexed along the last axis by
            ``rc.atom_order``.
        mask: Atom mask with the same atom indexing as ``positions``.
        aatype: Residue types. Not used by this function.

    Returns:
        A tuple ``(rigid, rigid_mask)``: the frame built by
        ``make_transform_from_reference`` from N, CA and C, and the product
        of the masks of those three atoms.
    """
    n_idx, ca_idx, c_idx = (
        rc.atom_order[atom_name] for atom_name in ('N', 'CA', 'C')
    )

    # The frame is only valid when all three backbone atoms are present.
    backbone_present = mask[..., n_idx] * mask[..., ca_idx] * mask[..., c_idx]

    backbone_frame = make_transform_from_reference(
        a_xyz=positions[..., n_idx],
        b_xyz=positions[..., ca_idx],
        c_xyz=positions[..., c_idx],
    )

    return backbone_frame, backbone_present
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment