OpenDAS / OpenFold · Commits

Commit e699d7d2, authored Feb 22, 2022 by Gustaf Ahdritz

    Start implementing Multimer

Parent: 61d004a2
Changes: 8 files
Showing 8 changed files, with 1379 additions and 83 deletions (+1379, -83):

  openfold/data/mmcif_parsing.py                            +3    -1
  openfold/data/msa_identifiers.py (new file)               +92   -0
  openfold/data/msa_pairing.py (new file)                   +626  -0
  openfold/data/multimer_feature_processing.py (new file)   +240  -0
  openfold/data/parsers.py                                   +273  -17
  openfold/np/protein.py                                     +121  -55
  openfold/np/residue_constants.py                           +4    -3
  openfold/utils/loss.py                                     +20   -7
openfold/data/mmcif_parsing.py

@@ -16,6 +16,7 @@
 """Parses the mmCIF file format."""
 import collections
 import dataclasses
+import functools
 import io
 import json
 import logging
@@ -173,6 +174,7 @@ def mmcif_loop_to_dict(
     return {entry[index]: entry for entry in entries}


+@functools.lru_cache(16, typed=False)
 def parse(
     *, file_id: str, mmcif_string: str, catch_all_errors: bool = True
 ) -> ParsingResult:
@@ -346,7 +348,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
                 raw_resolution = parsed_info[res_key][0]
                 header["resolution"] = float(raw_resolution)
             except ValueError:
-                logging.info(
+                logging.debug(
                     "Invalid resolution format: %s", parsed_info[res_key]
                 )
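The interesting change here is the new `functools.lru_cache(16, typed=False)` decorator on `parse`: up to 16 recent parses are memoized, keyed on the keyword arguments, so repeated requests for the same mmCIF string skip re-parsing. A minimal standalone sketch of that behavior (the `parse` body below is a stand-in, not the OpenFold parser):

import functools

@functools.lru_cache(16, typed=False)
def parse(*, file_id: str, mmcif_string: str, catch_all_errors: bool = True):
    # Stand-in body: pretend this is the expensive mmCIF parse.
    print(f"parsing {file_id} ...")
    return {"file_id": file_id, "num_lines": len(mmcif_string.splitlines())}

parse(file_id="1abc", mmcif_string="loop_\n_atom_site.id")  # parses
parse(file_id="1abc", mmcif_string="loop_\n_atom_site.id")  # served from cache
print(parse.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=16, currsize=1)

Note that the whole mmCIF string becomes part of the cache key, so the cache trades memory for repeated-parse speed.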
openfold/data/msa_identifiers.py (new file, mode 100644)

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for extracting identifiers from MSA sequence descriptions."""

import dataclasses
import re
from typing import Optional


# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
_UNIPROT_PATTERN = re.compile(
    r"""
    ^
    # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
    (?:tr|sp)
    \|
    # A primary accession number of the UniProtKB entry.
    (?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
    # Occasionally there is a _0 or _1 isoform suffix, which we ignore.
    (?:_\d)?
    \|
    # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
    # protein ID code.
    (?:[A-Za-z0-9]+)
    _
    # A mnemonic species identification code.
    (?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
    # Small BFD uses a final value after an underscore, which we ignore.
    (?:_\d+)?
    $
    """,
    re.VERBOSE)


@dataclasses.dataclass(frozen=True)
class Identifiers:
    uniprot_accession_id: str = ''
    species_id: str = ''


def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
    """Gets accession id and species from an msa sequence identifier.

    The sequence identifier has the format specified by
    _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
    An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`

    Args:
        msa_sequence_identifier: a sequence identifier.

    Returns:
        An `Identifiers` instance with a uniprot_accession_id and species_id.
        These can be empty in the case where no identifier was found.
    """
    matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
    if matches:
        return Identifiers(
            uniprot_accession_id=matches.group('AccessionIdentifier'),
            species_id=matches.group('SpeciesIdentifier'))
    return Identifiers()


def _extract_sequence_identifier(description: str) -> Optional[str]:
    """Extracts sequence identifier from description. Returns None if no match."""
    split_description = description.split()
    if split_description:
        return split_description[0].partition('/')[0]
    else:
        return None


def get_identifiers(description: str) -> Identifiers:
    """Computes extra MSA features from the description."""
    sequence_identifier = _extract_sequence_identifier(description)
    if sequence_identifier is None:
        return Identifiers()
    else:
        return _parse_sequence_identifier(sequence_identifier)
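For orientation, a small usage sketch of the new module (the import path `openfold.data.msa_identifiers` is assumed from the file location; the expected values follow from the regex above):

from openfold.data import msa_identifiers

ids = msa_identifiers.get_identifiers("tr|A0A146SKV9|A0A146SKV9_FUNHE/11-91")
print(ids.uniprot_accession_id)  # "A0A146SKV9"
print(ids.species_id)            # "FUNHE"

# Descriptions that do not look like UniProtKB entries yield empty fields.
print(msa_identifiers.get_identifiers("some_custom_database_hit"))
# Identifiers(uniprot_accession_id='', species_id='')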
openfold/data/msa_pairing.py (new file, mode 100644, +626 lines)

The diff for this file is collapsed in the web view and its contents are not shown here.
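Since the pairing module itself is not visible here, the following is only a rough, illustrative sketch of the general idea behind species-based MSA pairing (alignment rows from different chains' MSAs are keyed on the species identifier extracted by msa_identifiers and matched across chains); it is not the contents of msa_pairing.py:

def toy_pair_by_species(per_chain_rows):
    """per_chain_rows: one dict per chain mapping species_id -> MSA row index."""
    common_species = set(per_chain_rows[0])
    for rows in per_chain_rows[1:]:
        common_species &= set(rows)
    # One paired row per species seen in every chain: a tuple of row indices.
    return [
        tuple(rows[sp] for rows in per_chain_rows)
        for sp in sorted(common_species)
    ]

chain_a = {"HUMAN": 3, "MOUSE": 7, "FUNHE": 12}
chain_b = {"MOUSE": 1, "FUNHE": 4, "YEAST": 9}
print(toy_pair_by_species([chain_a, chain_b]))  # [(12, 4), (7, 1)]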
openfold/data/multimer_feature_processing.py (new file, mode 100644)

# Copyright 2021 DeepMind Technologies Limited
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data pipeline."""
from typing import Iterable, List, Mapping, MutableMapping

from openfold.data import msa_pairing
from openfold.np import residue_constants
import numpy as np

# TODO: Move this into the config
REQUIRED_FEATURES = frozenset({
    'aatype', 'all_atom_mask', 'all_atom_positions',
    'all_chains_entity_ids', 'all_crops_all_chains_mask',
    'all_crops_all_chains_positions', 'all_crops_all_chains_residue_ids',
    'assembly_num_chains', 'asym_id', 'bert_mask', 'cluster_bias_mask',
    'deletion_matrix', 'deletion_mean', 'entity_id', 'entity_mask',
    'mem_peak', 'msa', 'msa_mask', 'num_alignments', 'num_templates',
    'queue_size', 'residue_index', 'resolution', 'seq_length', 'seq_mask',
    'sym_id', 'template_aatype', 'template_all_atom_mask',
    'template_all_atom_positions'
})

MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048


def _is_homomer_or_monomer(chains: Iterable[Mapping[str, np.ndarray]]) -> bool:
    """Checks if a list of chains represents a homomer/monomer example."""
    # Note that an entity_id of 0 indicates padding.
    num_unique_chains = len(np.unique(np.concatenate(
        [np.unique(chain['entity_id'][chain['entity_id'] > 0])
         for chain in chains]
    )))
    return num_unique_chains == 1


def pair_and_merge(
    all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]],
    is_prokaryote: bool,
) -> Mapping[str, np.ndarray]:
    """Runs processing on features to augment, pair and merge.

    Args:
        all_chain_features: A MutableMap of dictionaries of features for each chain.
        is_prokaryote: Whether the target complex is from a prokaryotic or
            eukaryotic organism.

    Returns:
        A dictionary of features.
    """
    process_unmerged_features(all_chain_features)

    np_chains_list = list(all_chain_features.values())

    pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)

    if pair_msa_sequences:
        np_chains_list = msa_pairing.create_paired_features(
            chains=np_chains_list, prokaryotic=is_prokaryote
        )
        np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)

    np_chains_list = crop_chains(
        np_chains_list,
        msa_crop_size=MSA_CROP_SIZE,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES,
    )
    np_example = msa_pairing.merge_chain_features(
        np_chains_list=np_chains_list,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES,
    )
    np_example = process_final(np_example)

    return np_example


def crop_chains(
    chains_list: List[Mapping[str, np.ndarray]],
    msa_crop_size: int,
    pair_msa_sequences: bool,
    max_templates: int,
) -> List[Mapping[str, np.ndarray]]:
    """Crops the MSAs for a set of chains.

    Args:
        chains_list: A list of chains to be cropped.
        msa_crop_size: The total number of sequences to crop from the MSA.
        pair_msa_sequences: Whether we are operating in sequence-pairing mode.
        max_templates: The maximum templates to use per chain.

    Returns:
        The chains cropped.
    """
    # Apply the cropping.
    cropped_chains = []
    for chain in chains_list:
        cropped_chain = _crop_single_chain(
            chain,
            msa_crop_size=msa_crop_size,
            pair_msa_sequences=pair_msa_sequences,
            max_templates=max_templates,
        )
        cropped_chains.append(cropped_chain)

    return cropped_chains


def _crop_single_chain(
    chain: Mapping[str, np.ndarray],
    msa_crop_size: int,
    pair_msa_sequences: bool,
    max_templates: int,
) -> Mapping[str, np.ndarray]:
    """Crops msa sequences to `msa_crop_size`."""
    msa_size = chain['num_alignments']

    if pair_msa_sequences:
        msa_size_all_seq = chain['num_alignments_all_seq']
        msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)

        # We reduce the number of un-paired sequences, by the number of times a
        # sequence from this chain's MSA is included in the paired MSA. This keeps
        # the MSA size for each chain roughly constant.
        msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
        num_non_gapped_pairs = np.sum(
            np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1)
        )
        num_non_gapped_pairs = np.minimum(
            num_non_gapped_pairs, msa_crop_size_all_seq
        )

        # Restrict the unpaired crop size so that paired+unpaired sequences do not
        # exceed msa_seqs_per_chain for each chain.
        max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
        msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
    else:
        msa_crop_size = np.minimum(msa_size, msa_crop_size)

    include_templates = 'template_aatype' in chain and max_templates
    if include_templates:
        num_templates = chain['template_aatype'].shape[0]
        templates_crop_size = np.minimum(num_templates, max_templates)

    for k in chain:
        k_split = k.split('_all_seq')[0]
        if k_split in msa_pairing.TEMPLATE_FEATURES:
            chain[k] = chain[k][:templates_crop_size, :]
        elif k_split in msa_pairing.MSA_FEATURES:
            if '_all_seq' in k and pair_msa_sequences:
                chain[k] = chain[k][:msa_crop_size_all_seq, :]
            else:
                chain[k] = chain[k][:msa_crop_size, :]

    chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
    if include_templates:
        chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
    if pair_msa_sequences:
        chain['num_alignments_all_seq'] = np.asarray(
            msa_crop_size_all_seq, dtype=np.int32
        )
    return chain


def process_final(
    np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Final processing steps in data pipeline, after merging and pairing."""
    np_example = _correct_msa_restypes(np_example)
    np_example = _make_seq_mask(np_example)
    np_example = _make_msa_mask(np_example)
    np_example = _filter_features(np_example)
    return np_example


def _correct_msa_restypes(np_example):
    """Correct MSA restype to have the same order as residue_constants."""
    new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0)
    np_example['msa'] = np_example['msa'].astype(np.int32)
    return np_example


def _make_seq_mask(np_example):
    np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
    return np_example


def _make_msa_mask(np_example):
    """Mask features are all ones, but will later be zero-padded."""
    np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)

    seq_mask = (np_example['entity_id'] > 0).astype(np.float32)
    np_example['msa_mask'] *= seq_mask[None]

    return np_example


def _filter_features(
    np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Filters features of example to only those requested."""
    return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}


def process_unmerged_features(
    all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]
):
    """Postprocessing stage for per-chain features before merging."""
    num_chains = len(all_chain_features)
    for chain_features in all_chain_features.values():
        # Convert deletion matrices to float.
        chain_features['deletion_matrix'] = np.asarray(
            chain_features.pop('deletion_matrix_int'), dtype=np.float32
        )
        if 'deletion_matrix_int_all_seq' in chain_features:
            chain_features['deletion_matrix_all_seq'] = np.asarray(
                chain_features.pop('deletion_matrix_int_all_seq'),
                dtype=np.float32,
            )

        chain_features['deletion_mean'] = np.mean(
            chain_features['deletion_matrix'], axis=0
        )

        # Add all_atom_mask and dummy all_atom_positions based on aatype.
        all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
            chain_features['aatype']
        ]
        chain_features['all_atom_mask'] = all_atom_mask
        chain_features['all_atom_positions'] = np.zeros(
            list(all_atom_mask.shape) + [3]
        )

        # Add assembly_num_chains.
        chain_features['assembly_num_chains'] = np.asarray(num_chains)

    # Add entity_mask.
    for chain_features in all_chain_features.values():
        chain_features['entity_mask'] = (
            chain_features['entity_id'] != 0
        ).astype(np.int32)
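The least obvious part of _crop_single_chain is the crop-budget arithmetic, so here is the same computation on toy numbers (illustrative values only, using the module defaults above):

import numpy as np

MSA_CROP_SIZE = 2048

msa_size = np.asarray(5000)             # unpaired rows available for this chain
msa_size_all_seq = np.asarray(1500)     # paired ("all_seq") rows available
num_non_gapped_pairs = np.asarray(700)  # paired rows that are not all-gap here

# Paired rows are capped at half of the total crop budget.
msa_crop_size_all_seq = np.minimum(msa_size_all_seq, MSA_CROP_SIZE // 2)         # 1024
num_non_gapped_pairs = np.minimum(num_non_gapped_pairs, msa_crop_size_all_seq)   # 700

# The unpaired budget shrinks by the number of usable paired rows, so
# paired + unpaired stays near MSA_CROP_SIZE for each chain.
max_msa_crop_size = np.maximum(MSA_CROP_SIZE - num_non_gapped_pairs, 0)          # 1348
msa_crop_size = np.minimum(msa_size, max_msa_crop_size)                          # 1348

print(msa_crop_size_all_seq, msa_crop_size)  # 1024 1348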
openfold/data/parsers.py

@@ -18,12 +18,41 @@ import collections
 import dataclasses
 import re
 import string
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set

 DeletionMatrix = Sequence[Sequence[int]]


+@dataclasses.dataclass(frozen=True)
+class Msa:
+    """Class representing a parsed MSA file"""
+
+    sequences: Sequence[str]
+    deletion_matrix: DeletionMatrix
+    descriptions: Sequence[str]
+
+    def __post_init__(self):
+        if not (
+            len(self.sequences)
+            == len(self.deletion_matrix)
+            == len(self.descriptions)
+        ):
+            raise ValueError(
+                "All fields for an MSA must have the same length"
+            )
+
+    def __len__(self):
+        return len(self.sequences)
+
+    def truncate(self, max_seqs: int):
+        return Msa(
+            sequences=self.sequences[:max_seqs],
+            deletion_matrix=self.deletion_matrix[:max_seqs],
+            descriptions=self.descriptions[:max_seqs],
+        )
+
+
 @dataclasses.dataclass(frozen=True)
 class TemplateHit:
     """Class representing a template hit."""
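A quick sketch of how the new Msa container behaves (toy values):

msa = Msa(
    sequences=["MKV", "MRV", "M-V"],
    deletion_matrix=[[0, 0, 0], [0, 1, 0], [0, 0, 0]],
    descriptions=["query", "hit1", "hit2"],
)
assert len(msa) == 3
top2 = msa.truncate(max_seqs=2)
assert top2.sequences == ["MKV", "MRV"]

# Mismatched field lengths are rejected by __post_init__:
# Msa(sequences=["MKV"], deletion_matrix=[], descriptions=["query"])
#   -> ValueError: All fields for an MSA must have the same length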
@@ -31,7 +60,7 @@ class TemplateHit:
     index: int
     name: str
     aligned_cols: int
-    sum_probs: float
+    sum_probs: Optional[float]
     query: str
     hit_sequence: str
     indices_query: List[int]
@@ -67,9 +96,7 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
     return sequences, descriptions


-def parse_stockholm(
-    stockholm_string: str,
-) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
+def parse_stockholm(stockholm_string: str) -> Msa:
     """Parses sequences and deletion matrix from stockholm format alignment.

     Args:
@@ -124,10 +151,14 @@ def parse_stockholm(
                 deletion_count = 0
         deletion_matrix.append(deletion_vec)

-    return msa, deletion_matrix, list(name_to_sequence.keys())
+    return Msa(
+        sequences=msa,
+        deletion_matrix=deletion_matrix,
+        descriptions=list(name_to_sequence.keys()),
+    )


-def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
+def parse_a3m(a3m_string: str) -> Msa:
     """Parses sequences and deletion matrix from a3m format alignment.

     Args:
@@ -142,7 +173,7 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
         at `deletion_matrix[i][j]` is the number of residues deleted from
         the aligned sequence i at residue position j.
     """
-    sequences, _ = parse_fasta(a3m_string)
+    sequences, descriptions = parse_fasta(a3m_string)
     deletion_matrix = []
     for msa_sequence in sequences:
         deletion_vec = []
@@ -158,7 +189,11 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
     # Make the MSA matrix out of aligned (deletion-free) sequences.
     deletion_table = str.maketrans("", "", string.ascii_lowercase)
     aligned_sequences = [s.translate(deletion_table) for s in sequences]
-    return aligned_sequences, deletion_matrix
+    return Msa(
+        sequences=aligned_sequences,
+        deletion_matrix=deletion_matrix,
+        descriptions=descriptions,
+    )


 def _convert_sto_seq_to_a3m(
@@ -172,7 +207,9 @@ def _convert_sto_seq_to_a3m(
 def convert_stockholm_to_a3m(
-    stockholm_format: str, max_sequences: Optional[int] = None
+    stockholm_format: str,
+    max_sequences: Optional[int] = None,
+    remove_first_row_gaps: bool = True,
 ) -> str:
     """Converts MSA in Stockholm format to the A3M format."""
     descriptions = {}
@@ -210,13 +247,19 @@ def convert_stockholm_to_a3m(
     # Convert sto format to a3m line by line
     a3m_sequences = {}
-    # query_sequence is assumed to be the first sequence
-    query_sequence = next(iter(sequences.values()))
-    query_non_gaps = [res != "-" for res in query_sequence]
+    if remove_first_row_gaps:
+        # query_sequence is assumed to be the first sequence
+        query_sequence = next(iter(sequences.values()))
+        query_non_gaps = [res != "-" for res in query_sequence]
     for seqname, sto_sequence in sequences.items():
-        a3m_sequences[seqname] = "".join(
-            _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
-        )
+        # Dots are optional in a3m format and are commonly removed.
+        out_sequence = sto_sequence.replace('.', '')
+        if remove_first_row_gaps:
+            out_sequence = ''.join(
+                _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)
+            )
+        a3m_sequences[seqname] = out_sequence

     fasta_chunks = (
         f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
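To make the new remove_first_row_gaps flag concrete: when it is enabled, columns where the query (first) row has a gap are not emitted as aligned columns. The helper _convert_sto_seq_to_a3m is not shown in this diff, so the snippet below only sketches the presumed A3M convention, where residues falling in query-gap columns become lowercase insertions rather than match columns:

query = "MK-VA"
hit = "MRGVA"
query_non_gaps = [c != "-" for c in query]

a3m_hit = "".join(
    res if is_match else res.lower()
    for res, is_match in zip(hit, query_non_gaps)
)
print(a3m_hit)  # "MRgVA"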
@@ -225,6 +268,124 @@ def convert_stockholm_to_a3m(
     return "\n".join(fasta_chunks) + "\n"  # Include terminating newline.


+def _keep_line(line: str, seqnames: Set[str]) -> bool:
+    """Function to decide which lines to keep."""
+    if not line.strip():
+        return True
+    if line.strip() == '//':  # End tag
+        return True
+    if line.startswith('# STOCKHOLM'):  # Start tag
+        return True
+    if line.startswith('#=GC RF'):  # Reference Annotation Line
+        return True
+    if line[:4] == '#=GS':  # Description lines - keep if sequence in list.
+        _, seqname, _ = line.split(maxsplit=2)
+        return seqname in seqnames
+    elif line.startswith('#'):  # Other markup - filter out
+        return False
+    else:  # Alignment data - keep if sequence in list.
+        seqname = line.partition(' ')[0]
+        return seqname in seqnames
+
+
+def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
+    """Reads + truncates a Stockholm file while preventing excessive RAM usage."""
+    seqnames = set()
+    filtered_lines = []
+
+    with open(stockholm_msa_path) as f:
+        for line in f:
+            if line.strip() and not line.startswith(('#', '//')):
+                # Ignore blank lines, markup and end symbols - remainder are
+                # alignment sequence parts.
+                seqname = line.partition(' ')[0]
+                seqnames.add(seqname)
+                if len(seqnames) >= max_sequences:
+                    break
+
+        f.seek(0)
+        for line in f:
+            if _keep_line(line, seqnames):
+                filtered_lines.append(line)
+
+    return ''.join(filtered_lines)
+
+
+def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
+    """Removes empty columns (dashes-only) from a Stockholm MSA."""
+    processed_lines = {}
+    unprocessed_lines = {}
+    for i, line in enumerate(stockholm_msa.splitlines()):
+        if line.startswith('#=GC RF'):
+            reference_annotation_i = i
+            reference_annotation_line = line
+            # Reached the end of this chunk of the alignment. Process chunk.
+            _, _, first_alignment = line.rpartition(' ')
+            mask = []
+            for j in range(len(first_alignment)):
+                for _, unprocessed_line in unprocessed_lines.items():
+                    prefix, _, alignment = unprocessed_line.rpartition(' ')
+                    if alignment[j] != '-':
+                        mask.append(True)
+                        break
+                else:  # Every row contained a hyphen - empty column.
+                    mask.append(False)
+            # Add reference annotation for processing with mask.
+            unprocessed_lines[reference_annotation_i] = reference_annotation_line
+
+            if not any(mask):
+                # All columns were empty. Output empty lines for chunk.
+                for line_index in unprocessed_lines:
+                    processed_lines[line_index] = ''
+            else:
+                for line_index, unprocessed_line in unprocessed_lines.items():
+                    prefix, _, alignment = unprocessed_line.rpartition(' ')
+                    masked_alignment = ''.join(
+                        itertools.compress(alignment, mask)
+                    )
+                    processed_lines[line_index] = f'{prefix} {masked_alignment}'
+
+            # Clear raw_alignments.
+            unprocessed_lines = {}
+        elif line.strip() and not line.startswith(('#', '//')):
+            unprocessed_lines[i] = line
+        else:
+            processed_lines[i] = line
+    return '\n'.join(
+        (processed_lines[i] for i in range(len(processed_lines)))
+    )
+
+
+def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
+    """Remove duplicate sequences (ignoring insertions wrt query)."""
+    sequence_dict = collections.defaultdict(str)
+
+    # First we must extract all sequences from the MSA.
+    for line in stockholm_msa.splitlines():
+        # Only consider the alignments - ignore reference annotation, empty
+        # lines, descriptions or markup.
+        if line.strip() and not line.startswith(('#', '//')):
+            line = line.strip()
+            seqname, alignment = line.split()
+            sequence_dict[seqname] += alignment
+
+    seen_sequences = set()
+    seqnames = set()
+    # First alignment is the query.
+    query_align = next(iter(sequence_dict.values()))
+    mask = [c != '-' for c in query_align]  # Mask is False for insertions.
+    for seqname, alignment in sequence_dict.items():
+        # Apply mask to remove all insertions from the string.
+        masked_alignment = ''.join(itertools.compress(alignment, mask))
+        if masked_alignment in seen_sequences:
+            continue
+        else:
+            seen_sequences.add(masked_alignment)
+            seqnames.add(seqname)
+
+    filtered_lines = []
+    for line in stockholm_msa.splitlines():
+        if _keep_line(line, seqnames):
+            filtered_lines.append(line)
+
+    return '\n'.join(filtered_lines) + '\n'
+
+
 def _get_hhr_line_regex_groups(
     regex_pattern: str, line: str
 ) -> Sequence[Optional[str]]:
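A small usage sketch of the new Stockholm helpers on a toy alignment (illustrative values; deduplicate_stockholm_msa and _keep_line are the functions added above):

sto = "\n".join([
    "# STOCKHOLM 1.0",
    "#=GS seq1 DE first hit",
    "query GGWVRQ",
    "seq1  GG-VRQ",
    "seq2  GG-VRQ",  # identical to seq1, so it should be dropped
    "//",
])

deduped = deduplicate_stockholm_msa(sto)
assert "seq1" in deduped          # first copy of the duplicated row is kept
assert "seq2" not in deduped      # duplicate alignment row is removed
assert "# STOCKHOLM" in deduped   # start/end markup is preserved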
@@ -278,7 +439,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
             "Could not parse section: %s. Expected this:\n%s to contain summary."
             % (detailed_lines, detailed_lines[2])
         )
-    (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
+    (_, _, _, aligned_cols, _, _, sum_probs, _) = [
         float(x) for x in match.groups()
     ]
@@ -386,3 +547,98 @@ def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
         target_name = fields[0]
         e_values[target_name] = float(e_value)
     return e_values
+
+
+def _get_indices(sequence: str, start: int) -> List[int]:
+    """Returns indices for non-gap/insert residues starting at the given index."""
+    indices = []
+    counter = start
+    for symbol in sequence:
+        # Skip gaps but add a placeholder so that the alignment is preserved.
+        if symbol == '-':
+            indices.append(-1)
+        # Skip deleted residues, but increase the counter.
+        elif symbol.islower():
+            counter += 1
+        # Normal aligned residue. Increase the counter and append to indices.
+        else:
+            indices.append(counter)
+            counter += 1
+    return indices
+
+
+@dataclasses.dataclass(frozen=True)
+class HitMetadata:
+    pdb_id: str
+    chain: str
+    start: int
+    end: int
+    length: int
+    text: str
+
+
+def _parse_hmmsearch_description(description: str) -> HitMetadata:
+    """Parses the hmmsearch A3M sequence description line."""
+    # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217  Free text
+    # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
+    match = re.match(
+        r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
+        description.strip(),
+    )
+
+    if not match:
+        raise ValueError(f'Could not parse description: "{description}".')
+
+    return HitMetadata(
+        pdb_id=match[1],
+        chain=match[2],
+        start=int(match[3]),
+        end=int(match[4]),
+        length=int(match[5]),
+        text=match[6],
+    )
+
+
+def parse_hmmsearch_a3m(
+    query_sequence: str,
+    a3m_string: str,
+    skip_first: bool = True,
+) -> Sequence[TemplateHit]:
+    """Parses an a3m string produced by hmmsearch.
+
+    Args:
+        query_sequence: The query sequence.
+        a3m_string: The a3m string produced by hmmsearch.
+        skip_first: Whether to skip the first sequence in the a3m string.
+
+    Returns:
+        A sequence of `TemplateHit` results.
+    """
+    # Zip the descriptions and MSAs together, skip the first query sequence.
+    parsed_a3m = list(zip(*parse_fasta(a3m_string)))
+    if skip_first:
+        parsed_a3m = parsed_a3m[1:]
+
+    indices_query = _get_indices(query_sequence, start=0)
+
+    hits = []
+    for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
+        if 'mol:protein' not in hit_description:
+            continue  # Skip non-protein chains.
+        metadata = _parse_hmmsearch_description(hit_description)
+        # Aligned columns are only the match states.
+        aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence])
+        indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)
+
+        hit = TemplateHit(
+            index=i,
+            name=f'{metadata.pdb_id}_{metadata.chain}',
+            aligned_cols=aligned_cols,
+            sum_probs=None,
+            query=query_sequence,
+            hit_sequence=hit_sequence.upper(),
+            indices_query=indices_query,
+            indices_hit=indices_hit,
+        )
+        hits.append(hit)
+
+    return hits
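Two small behaviors worth pinning down from the additions above (toy inputs; the expected values follow directly from the code):

# '-' keeps the alignment position with a -1 placeholder; lowercase letters
# are treated as deletions that advance the counter without emitting an index.
print(_get_indices("AB-cD", start=0))  # [0, 1, -1, 3]

# The hmmsearch description regex pulls apart the PDB id, chain and range.
meta = _parse_hmmsearch_description(
    ">4pqx_A/2-217 [subseq from] mol:protein length:217  Free text"
)
print(meta.pdb_id, meta.chain, meta.start, meta.end)  # 4pqx A 2 217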
openfold/np/protein.py

@@ -28,6 +28,11 @@ FeatureDict = Mapping[str, np.ndarray]
 ModelOutput = Mapping[str, Any]  # Is a nested dict.
 PICO_TO_ANGSTROM = 0.01

+PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)
+assert(PDB_MAX_CHAINS == 62)
+

 @dataclasses.dataclass(frozen=True)
 class Protein:
     """Protein structure representation."""
@@ -46,12 +51,23 @@ class Protein:
     # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
     residue_index: np.ndarray  # [num_res]

+    # 0-indexed number corresponding to the chain in the protein that this
+    # residue belongs to
+    chain_index: np.ndarray  # [num_res]
+
     # B-factors, or temperature factors, of each residue (in sq. angstroms units),
     # representing the displacement of the residue from its ground truth mean
     # value.
     b_factors: np.ndarray  # [num_res, num_atom_type]

+    def __post_init__(self):
+        if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
+            raise ValueError(
+                f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
+                "chains because these cannot be written to PDB format"
+            )
+

 def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
     """Takes a PDB string and constructs a Protein object.
@@ -61,9 +77,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
     Args:
         pdb_str: The contents of the pdb file
-        chain_id: If None, then the pdb file must contain a single chain (which
-            will be parsed). If chain_id is specified (e.g. A), then only that
-            chain is parsed.
+        chain_id: If chain_id is specified (e.g. A), then only that chain is
+            parsed. Else, all chains are parsed.

     Returns:
         A new `Protein` parsed from the pdb contents.
@@ -78,59 +93,61 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
     )
     model = models[0]

-    if chain_id is not None:
-        chain = model[chain_id]
-    else:
-        chains = list(model.get_chains())
-        if len(chains) != 1:
-            raise ValueError(
-                "Only single chain PDBs are supported when chain_id not specified. "
-                f"Found {len(chains)} chains."
-            )
-        else:
-            chain = chains[0]
-
     atom_positions = []
     aatype = []
     atom_mask = []
     residue_index = []
+    chain_ids = []
     b_factors = []

-    for res in chain:
-        if res.id[2] != " ":
-            raise ValueError(
-                f"PDB contains an insertion code at chain {chain.id} and residue "
-                f"index {res.id[1]}. These are not supported."
-            )
-        res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
-        restype_idx = residue_constants.restype_order.get(
-            res_shortname, residue_constants.restype_num
-        )
-        pos = np.zeros((residue_constants.atom_type_num, 3))
-        mask = np.zeros((residue_constants.atom_type_num,))
-        res_b_factors = np.zeros((residue_constants.atom_type_num,))
-        for atom in res:
-            if atom.name not in residue_constants.atom_types:
-                continue
-            pos[residue_constants.atom_order[atom.name]] = atom.coord
-            mask[residue_constants.atom_order[atom.name]] = 1.0
-            res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
-        if np.sum(mask) < 0.5:
-            # If no known atom positions are reported for the residue then skip it.
-            continue
-        aatype.append(restype_idx)
-        atom_positions.append(pos)
-        atom_mask.append(mask)
-        residue_index.append(res.id[1])
-        b_factors.append(res_b_factors)
+    for chain in model:
+        if chain_id is not None and chain.id != chain_id:
+            continue
+        for res in chain:
+            if res.id[2] != " ":
+                raise ValueError(
+                    f"PDB contains an insertion code at chain {chain.id} and residue "
+                    f"index {res.id[1]}. These are not supported."
+                )
+            res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
+            restype_idx = residue_constants.restype_order.get(
+                res_shortname, residue_constants.restype_num
+            )
+            pos = np.zeros((residue_constants.atom_type_num, 3))
+            mask = np.zeros((residue_constants.atom_type_num,))
+            res_b_factors = np.zeros((residue_constants.atom_type_num,))
+            for atom in res:
+                if atom.name not in residue_constants.atom_types:
+                    continue
+                pos[residue_constants.atom_order[atom.name]] = atom.coord
+                mask[residue_constants.atom_order[atom.name]] = 1.0
+                res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
+            if np.sum(mask) < 0.5:
+                # If no known atom positions are reported for the residue then skip it.
+                continue
+            aatype.append(restype_idx)
+            atom_positions.append(pos)
+            atom_mask.append(mask)
+            residue_index.append(res.id[1])
+            chain_ids.append(chain.id)
+            b_factors.append(res_b_factors)
+
+    # Chain IDs are usually characters so map these to ints
+    unique_chain_ids = np.unique(chain_ids)
+    chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
+    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

     return Protein(
         atom_positions=np.array(atom_positions),
         atom_mask=np.array(atom_mask),
         aatype=np.array(aatype),
         residue_index=np.array(residue_index),
+        chain_index=chain_index,
         b_factors=np.array(b_factors),
     )
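The multi-chain rewrite of from_pdb_string boils down to collecting an author-assigned chain ID per residue and densely re-indexing it; here is that mapping step in isolation (toy values):

import numpy as np

chain_ids = ["B", "B", "A", "A", "A"]      # one chain ID per parsed residue
unique_chain_ids = np.unique(chain_ids)    # ['A', 'B'] (sorted)
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
print(chain_index)                         # [1 1 0 0 0]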
@@ -188,6 +205,14 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
     )


+def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
+    chain_end = "TER"
+    return (
+        f"{chain_end:<6}{atom_index:>5}      {end_resname:>3} "
+        f"{chain_name:>1}{residue_index:>4}"
+    )
+
+
 def to_pdb(prot: Protein) -> str:
     """Converts a `Protein` instance to a PDB string.
@@ -207,16 +232,39 @@ def to_pdb(prot: Protein) -> str:
     aatype = prot.aatype
     atom_positions = prot.atom_positions
     residue_index = prot.residue_index.astype(np.int32)
+    chain_index = prot.chain_index.astype(np.int32)
     b_factors = prot.b_factors

     if np.any(aatype > residue_constants.restype_num):
         raise ValueError("Invalid aatypes.")

+    # Construct a mapping from chain integer indices to chain ID strings.
+    chain_ids = {}
+    for i in np.unique(chain_index):  # np.unique gives sorted output.
+        if i >= PDB_MAX_CHAINS:
+            raise ValueError(
+                f"The PDB format supports at most {PDB_MAX_CHAINS} chains."
+            )
+        chain_ids[i] = PDB_CHAIN_IDS[i]
+
     pdb_lines.append("MODEL     1")
     atom_index = 1
-    chain_id = "A"
+    last_chain_index = chain_index[0]
     # Add all atom sites.
     for i in range(aatype.shape[0]):
+        # Close the previous chain if in a multichain PDB.
+        if last_chain_index != chain_index[i]:
+            pdb_lines.append(
+                _chain_end(
+                    atom_index,
+                    res_1to3(aatype[i - 1]),
+                    chain_ids[chain_index[i - 1]],
+                    residue_index[i - 1],
+                )
+            )
+            last_chain_index = chain_index[i]
+            atom_index += 1  # Atom index increases at the TER symbol.
+
         res_name_3 = res_1to3(aatype[i])
         for atom_name, pos, mask, b_factor in zip(
             atom_types, atom_positions[i], atom_mask[i], b_factors[i]
@@ -236,7 +284,7 @@ def to_pdb(prot: Protein) -> str:
             # PDB is a columnar format, every space matters here!
             atom_line = (
                 f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
-                f"{res_name_3:>3} {chain_id:>1}"
+                f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
                 f"{residue_index[i]:>4}{insertion_code:>1}   "
                 f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                 f"{occupancy:>6.2f}{b_factor:>6.2f}          "
@@ -245,18 +293,22 @@ def to_pdb(prot: Protein) -> str:
             pdb_lines.append(atom_line)
             atom_index += 1

-    # Close the chain.
-    chain_end = "TER"
-    chain_termination_line = (
-        f"{chain_end:<6}{atom_index:>5}      {res_1to3(aatype[-1]):>3} "
-        f"{chain_id:>1}{residue_index[-1]:>4}"
-    )
-    pdb_lines.append(chain_termination_line)
+    # Close the final chain.
+    pdb_lines.append(
+        _chain_end(
+            atom_index,
+            res_1to3(aatype[-1]),
+            chain_ids[chain_index[-1]],
+            residue_index[-1],
+        )
+    )
     pdb_lines.append("ENDMDL")
     pdb_lines.append("END")
-    pdb_lines.append("")
-    return "\n".join(pdb_lines)
+
+    # Pad all lines to 80 characters
+    pdb_lines = [line.ljust(80) for line in pdb_lines]
+    return "\n".join(pdb_lines) + "\n"  # Add terminating newline.


 def ideal_atom_mask(prot: Protein) -> np.ndarray:
@@ -279,6 +331,7 @@ def from_prediction(
     features: FeatureDict,
     result: ModelOutput,
     b_factors: Optional[np.ndarray] = None,
+    remove_leading_feature_dimension: bool = True,
 ) -> Protein:
     """Assembles a protein from a prediction.
@@ -286,17 +339,30 @@ def from_prediction(
         features: Dictionary holding model inputs.
         result: Dictionary holding model outputs.
         b_factors: (Optional) B-factors to use for the protein.
+        remove_leading_feature_dimension: Whether to remove the leading dimension
+            of the `features` values.

     Returns:
         A protein instance.
     """
+    def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
+        return arr[0] if remove_leading_feature_dimension else arr
+
+    if "asym_id" in features:
+        chain_index = _maybe_remove_leading_dim(features["asym_id"])
+    else:
+        chain_index = np.zeros_like(_maybe_remove_leading_dim(features["aatype"]))
+
     if b_factors is None:
         b_factors = np.zeros_like(result["final_atom_mask"])

     return Protein(
-        aatype=features["aatype"],
+        aatype=_maybe_remove_leading_dim(features["aatype"]),
         atom_positions=result["final_atom_positions"],
         atom_mask=result["final_atom_mask"],
-        residue_index=features["residue_index"] + 1,
+        residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
+        chain_index=chain_index,
         b_factors=b_factors,
     )
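Two conventions introduced above, shown in isolation: integer chain indices select letters from PDB_CHAIN_IDS when writing PDB files, and from_prediction can optionally strip a leading batch dimension from feature arrays (toy shapes, illustrative only):

import numpy as np

PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"

chain_index = np.array([0, 0, 1, 2])
print([PDB_CHAIN_IDS[i] for i in np.unique(chain_index)])  # ['A', 'B', 'C']

def _maybe_remove_leading_dim(arr, remove_leading_feature_dimension=True):
    return arr[0] if remove_leading_feature_dimension else arr

aatype = np.zeros((1, 4), dtype=np.int64)       # e.g. [batch, num_res]
print(_maybe_remove_leading_dim(aatype).shape)  # (4,)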
openfold/np/residue_constants.py

@@ -17,6 +17,7 @@
 import collections
 import functools
 import os
+from typing import Mapping, List, Tuple
 from importlib import resources
@@ -448,9 +449,9 @@ def load_stereo_chemical_props() -> Tuple[
         ("residue_virtual_bonds").

     Returns:
-      residue_bonds: dict that maps resname --> list of Bond tuples
-      residue_virtual_bonds: dict that maps resname --> list of Bond tuples
-      residue_bond_angles: dict that maps resname --> list of BondAngle tuples
+      residue_bonds: Dict that maps resname -> list of Bond tuples
+      residue_virtual_bonds: Dict that maps resname -> list of Bond tuples
+      residue_bond_angles: Dict that maps resname -> list of BondAngle tuples
     """
     # TODO: this file should be downloaded in a setup script
     stereo_chemical_props = resources.read_text(
         "openfold.resources", "stereo_chemical_props.txt"
     )
openfold/utils/loss.py

@@ -619,6 +619,8 @@ def compute_predicted_aligned_error(
 def compute_tm(
     logits: torch.Tensor,
     residue_weights: Optional[torch.Tensor] = None,
+    asym_id: Optional[torch.Tensor] = None,
+    interface: bool = False,
     max_bin: int = 31,
     no_bins: int = 64,
     eps: float = 1e-8,
@@ -632,9 +634,9 @@ def compute_tm(
     )

     bin_centers = _calculate_bin_centers(boundaries)
-    torch.sum(residue_weights)
-    n = logits.shape[-2]
-    clipped_n = max(n, 19)
+    soft_n = torch.sum(residue_weights, dim=-1).to(torch.int32)
+    other = n.new_zeros() + 19
+    clipped_n = torch.max(soft_n, other, dim=-1)

     d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
@@ -643,11 +645,22 @@ def compute_tm(
     tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
     predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

-    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
+    n = residue_weights.shape[-1]
+    pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
+    if interface:
+        pair_mask *= (asym_id[..., None] != asym_id[..., None, :])
+
+    predicted_tm_term *= pair_mask
+
+    pair_residue_weights = pair_mask * (
+        residue_weights[..., None, :] * residue_weights[..., :, None]
+    )
+    denom = eps + torch.sum(pair_residue_weights, dim=-1, keepdims=True)
+    normed_residue_mask = pair_residue_weights / denom
     per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

     weighted = per_alignment * residue_weights

-    argmax = (weighted == torch.max(weighted)).nonzero()[0]
-    return per_alignment[tuple(argmax)]
+    idx = weighted.argmax(dim=-1, keepdim=True)
+    return torch.gather(per_alignment, -1, idx).squeeze(-1)
@@ -701,7 +714,7 @@ def tm_loss(
         (resolution >= min_resolution) & (resolution <= max_resolution)
     )

-    # Average over the loss dimension
+    # Average over the batch dimension
     loss = torch.mean(loss)

     return loss
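To see what the new interface path in compute_tm masks out, here is the pair-mask construction on a toy asym_id (two chains of two residues each); only inter-chain (i, j) pairs survive, which is the set an interface-TM style score averages over:

import torch

asym_id = torch.tensor([0, 0, 1, 1])   # chain membership per residue
residue_weights = torch.ones(4)

n = residue_weights.shape[-1]
pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
pair_mask *= (asym_id[..., None] != asym_id[..., None, :])
print(pair_mask)
# tensor([[0, 0, 1, 1],
#         [0, 0, 1, 1],
#         [1, 1, 0, 0],
#         [1, 1, 0, 0]], dtype=torch.int32)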