Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastFold
Commits
b7ee0ff3
Unverified
Commit
b7ee0ff3
authored
Jan 05, 2023
by
shenggan
Committed by
GitHub
Jan 05, 2023
Browse files
fix wrong para import and update protein class (#130)
parent
a80d5263
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
181 additions
and
84 deletions
+181
-84
fastfold/common/protein.py
fastfold/common/protein.py
+161
-83
fastfold/data/data_pipeline.py
fastfold/data/data_pipeline.py
+1
-1
fastfold/utils/import_weights.py
fastfold/utils/import_weights.py
+19
-0
No files found.
fastfold/common/protein.py
View file @
b7ee0ff3
...
...
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import
dataclasses
import
io
...
...
@@ -22,10 +23,15 @@ from fastfold.common import residue_constants
from
Bio.PDB
import
PDBParser
import
numpy
as
np
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
ModelOutput
=
Mapping
[
str
,
Any
]
# Is a nested dict.
PICO_TO_ANGSTROM
=
0.01
PDB_CHAIN_IDS
=
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
PDB_MAX_CHAINS
=
len
(
PDB_CHAIN_IDS
)
assert
(
PDB_MAX_CHAINS
==
62
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
Protein
:
...
...
@@ -45,25 +51,32 @@ class Protein:
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index
:
np
.
ndarray
# [num_res]
# 0-indexed number corresponding to the chain in the protein that this
# residue belongs to
chain_index
:
np
.
ndarray
# [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors
:
np
.
ndarray
# [num_res, num_atom_type]
def
__post_init__
(
self
):
if
(
len
(
np
.
unique
(
self
.
chain_index
))
>
PDB_MAX_CHAINS
):
raise
ValueError
(
f
"Cannot build an instance with more than
{
PDB_MAX_CHAINS
}
"
"chains because these cannot be written to PDB format"
)
def
from_pdb_string
(
pdb_str
:
str
,
chain_id
:
Optional
[
str
]
=
None
)
->
Protein
:
"""Takes a PDB string and constructs a Protein object.
WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored.
Args:
pdb_str: The contents of the pdb file
chain_id: If None, then the pdb file must contain a single chain (which
will be parsed). If chain_id is specified (e.g. A), then only that chain
is parsed.
chain_id: If chain_id is specified (e.g. A), then only that chain is
parsed. Else, all chains are parsed.
Returns:
A new `Protein` parsed from the pdb contents.
"""
...
...
@@ -72,95 +85,110 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
structure
=
parser
.
get_structure
(
"none"
,
pdb_fh
)
models
=
list
(
structure
.
get_models
())
if
len
(
models
)
!=
1
:
raise
ValueError
(
f
"Only single model PDBs are supported. Found
{
len
(
models
)
}
models."
)
raise
ValueError
(
f
"Only single model PDBs are supported. Found
{
len
(
models
)
}
models."
)
model
=
models
[
0
]
if
chain_id
is
not
None
:
chain
=
model
[
chain_id
]
else
:
chains
=
list
(
model
.
get_chains
())
if
len
(
chains
)
!=
1
:
raise
ValueError
(
"Only single chain PDBs are supported when chain_id not specified. "
f
"Found
{
len
(
chains
)
}
chains."
)
else
:
chain
=
chains
[
0
]
atom_positions
=
[]
aatype
=
[]
atom_mask
=
[]
residue_index
=
[]
chain_ids
=
[]
b_factors
=
[]
for
res
in
chain
:
if
res
.
id
[
2
]
!=
" "
:
raise
ValueError
(
f
"PDB contains an insertion code at chain
{
chain
.
id
}
and residue "
f
"index
{
res
.
id
[
1
]
}
. These are not supported."
)
res_shortname
=
residue_constants
.
restype_3to1
.
get
(
res
.
resname
,
"X"
)
restype_idx
=
residue_constants
.
restype_order
.
get
(
res_shortname
,
residue_constants
.
restype_num
)
pos
=
np
.
zeros
((
residue_constants
.
atom_type_num
,
3
))
mask
=
np
.
zeros
((
residue_constants
.
atom_type_num
,))
res_b_factors
=
np
.
zeros
((
residue_constants
.
atom_type_num
,))
for
atom
in
res
:
if
atom
.
name
not
in
residue_constants
.
atom_types
:
continue
pos
[
residue_constants
.
atom_order
[
atom
.
name
]]
=
atom
.
coord
mask
[
residue_constants
.
atom_order
[
atom
.
name
]]
=
1.0
res_b_factors
[
residue_constants
.
atom_order
[
atom
.
name
]]
=
atom
.
bfactor
if
np
.
sum
(
mask
)
<
0.5
:
# If no known atom positions are reported for the residue then skip it.
for
chain
in
model
:
if
(
chain_id
is
not
None
and
chain
.
id
!=
chain_id
):
continue
aatype
.
append
(
restype_idx
)
atom_positions
.
append
(
pos
)
atom_mask
.
append
(
mask
)
residue_index
.
append
(
res
.
id
[
1
])
b_factors
.
append
(
res_b_factors
)
for
res
in
chain
:
if
res
.
id
[
2
]
!=
" "
:
raise
ValueError
(
f
"PDB contains an insertion code at chain
{
chain
.
id
}
and residue "
f
"index
{
res
.
id
[
1
]
}
. These are not supported."
)
res_shortname
=
residue_constants
.
restype_3to1
.
get
(
res
.
resname
,
"X"
)
restype_idx
=
residue_constants
.
restype_order
.
get
(
res_shortname
,
residue_constants
.
restype_num
)
pos
=
np
.
zeros
((
residue_constants
.
atom_type_num
,
3
))
mask
=
np
.
zeros
((
residue_constants
.
atom_type_num
,))
res_b_factors
=
np
.
zeros
((
residue_constants
.
atom_type_num
,))
for
atom
in
res
:
if
atom
.
name
not
in
residue_constants
.
atom_types
:
continue
pos
[
residue_constants
.
atom_order
[
atom
.
name
]]
=
atom
.
coord
mask
[
residue_constants
.
atom_order
[
atom
.
name
]]
=
1.0
res_b_factors
[
residue_constants
.
atom_order
[
atom
.
name
]
]
=
atom
.
bfactor
if
np
.
sum
(
mask
)
<
0.5
:
# If no known atom positions are reported for the residue then skip it.
continue
aatype
.
append
(
restype_idx
)
atom_positions
.
append
(
pos
)
atom_mask
.
append
(
mask
)
residue_index
.
append
(
res
.
id
[
1
])
chain_ids
.
append
(
chain
.
id
)
b_factors
.
append
(
res_b_factors
)
# Chain IDs are usually characters so map these to ints
unique_chain_ids
=
np
.
unique
(
chain_ids
)
chain_id_mapping
=
{
cid
:
n
for
n
,
cid
in
enumerate
(
unique_chain_ids
)}
chain_index
=
np
.
array
([
chain_id_mapping
[
cid
]
for
cid
in
chain_ids
])
return
Protein
(
atom_positions
=
np
.
array
(
atom_positions
),
atom_mask
=
np
.
array
(
atom_mask
),
aatype
=
np
.
array
(
aatype
),
residue_index
=
np
.
array
(
residue_index
),
chain_index
=
chain_index
,
b_factors
=
np
.
array
(
b_factors
),
)
def
from_proteinnet_string
(
proteinnet_str
:
str
)
->
Protein
:
tag_re
=
r
'(\[[A-Z]+\]\n)'
tags
=
[
tag
.
strip
()
for
tag
in
re
.
split
(
tag_re
,
proteinnet_str
)
if
len
(
tag
)
>
0
]
tags
=
[
tag
.
strip
()
for
tag
in
re
.
split
(
tag_re
,
proteinnet_str
)
if
len
(
tag
)
>
0
]
groups
=
zip
(
tags
[
0
::
2
],
[
l
.
split
(
'
\n
'
)
for
l
in
tags
[
1
::
2
]])
atoms
=
[
'N'
,
'CA'
,
'C'
]
aatype
=
None
atom_positions
=
None
atom_mask
=
None
for
g
in
groups
:
if
(
"[PRIMARY]"
==
g
[
0
]):
if
(
"[PRIMARY]"
==
g
[
0
]):
seq
=
g
[
1
][
0
].
strip
()
for
i
in
range
(
len
(
seq
)):
if
(
seq
[
i
]
not
in
residue_constants
.
restypes
):
if
(
seq
[
i
]
not
in
residue_constants
.
restypes
):
seq
[
i
]
=
'X'
aatype
=
np
.
array
([
residue_constants
.
restype_order
.
get
(
res_symbol
,
residue_constants
.
restype_num
)
for
res_symbol
in
seq
residue_constants
.
restype_order
.
get
(
res_symbol
,
residue_constants
.
restype_num
)
for
res_symbol
in
seq
])
elif
(
"[TERTIARY]"
==
g
[
0
]):
elif
(
"[TERTIARY]"
==
g
[
0
]):
tertiary
=
[]
for
axis
in
range
(
3
):
tertiary
.
append
(
list
(
map
(
float
,
g
[
1
][
axis
].
split
())))
tertiary_np
=
np
.
array
(
tertiary
)
atom_positions
=
np
.
zeros
(
(
len
(
tertiary
[
0
])
//
3
,
residue_constants
.
atom_type_num
,
3
)).
astype
(
np
.
float32
)
(
len
(
tertiary
[
0
])
//
3
,
residue_constants
.
atom_type_num
,
3
)
).
astype
(
np
.
float32
)
for
i
,
atom
in
enumerate
(
atoms
):
atom_positions
[:,
residue_constants
.
atom_order
[
atom
],
:]
=
(
np
.
transpose
(
tertiary_np
[:,
i
::
3
]))
atom_positions
[:,
residue_constants
.
atom_order
[
atom
],
:]
=
(
np
.
transpose
(
tertiary_np
[:,
i
::
3
])
)
atom_positions
*=
PICO_TO_ANGSTROM
elif
(
"[MASK]"
==
g
[
0
]):
elif
(
"[MASK]"
==
g
[
0
]):
mask
=
np
.
array
(
list
(
map
({
'-'
:
0
,
'+'
:
1
}.
get
,
g
[
1
][
0
].
strip
())))
atom_mask
=
np
.
zeros
((
len
(
mask
),
residue_constants
.
atom_type_num
,
)).
astype
(
np
.
float32
)
atom_mask
=
np
.
zeros
(
(
len
(
mask
),
residue_constants
.
atom_type_num
,)
).
astype
(
np
.
float32
)
for
i
,
atom
in
enumerate
(
atoms
):
atom_mask
[:,
residue_constants
.
atom_order
[
atom
]]
=
1
atom_mask
*=
mask
[...,
None
]
...
...
@@ -174,12 +202,18 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
)
def
_chain_end
(
atom_index
,
end_resname
,
chain_name
,
residue_index
)
->
str
:
chain_end
=
'TER'
return
(
f
'
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
{
end_resname
:
>
3
}
'
f
'
{
chain_name
:
>
1
}{
residue_index
:
>
4
}
'
)
def
to_pdb
(
prot
:
Protein
)
->
str
:
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
...
...
@@ -193,19 +227,43 @@ def to_pdb(prot: Protein) -> str:
aatype
=
prot
.
aatype
atom_positions
=
prot
.
atom_positions
residue_index
=
prot
.
residue_index
.
astype
(
np
.
int32
)
chain_index
=
prot
.
chain_index
.
astype
(
np
.
int32
)
b_factors
=
prot
.
b_factors
if
np
.
any
(
aatype
>
residue_constants
.
restype_num
):
raise
ValueError
(
"Invalid aatypes."
)
# Construct a mapping from chain integer indices to chain ID strings.
chain_ids
=
{}
for
i
in
np
.
unique
(
chain_index
):
# np.unique gives sorted output.
if
i
>=
PDB_MAX_CHAINS
:
raise
ValueError
(
f
"The PDB format supports at most
{
PDB_MAX_CHAINS
}
chains."
)
chain_ids
[
i
]
=
PDB_CHAIN_IDS
[
i
]
pdb_lines
.
append
(
"MODEL 1"
)
atom_index
=
1
chain_i
d
=
"A"
last_
chain_i
ndex
=
chain_index
[
0
]
# Add all atom sites.
for
i
in
range
(
aatype
.
shape
[
0
]):
# Close the previous chain if in a multichain PDB.
if
last_chain_index
!=
chain_index
[
i
]:
pdb_lines
.
append
(
_chain_end
(
atom_index
,
res_1to3
(
aatype
[
i
-
1
]),
chain_ids
[
chain_index
[
i
-
1
]],
residue_index
[
i
-
1
]
)
)
last_chain_index
=
chain_index
[
i
]
atom_index
+=
1
# Atom index increases at the TER symbol.
res_name_3
=
res_1to3
(
aatype
[
i
])
for
atom_name
,
pos
,
mask
,
b_factor
in
zip
(
atom_types
,
atom_positions
[
i
],
atom_mask
[
i
],
b_factors
[
i
]):
for
atom_name
,
pos
,
mask
,
b_factor
in
zip
(
atom_types
,
atom_positions
[
i
],
atom_mask
[
i
],
b_factors
[
i
]
):
if
mask
<
0.5
:
continue
...
...
@@ -214,40 +272,47 @@ def to_pdb(prot: Protein) -> str:
alt_loc
=
""
insertion_code
=
""
occupancy
=
1.00
element
=
atom_name
[
0
]
# Protein supports only C, N, O, S, this works.
element
=
atom_name
[
0
]
# Protein supports only C, N, O, S, this works.
charge
=
""
# PDB is a columnar format, every space matters here!
atom_line
=
(
f
"
{
record_type
:
<
6
}{
atom_index
:
>
5
}
{
name
:
<
4
}{
alt_loc
:
>
1
}
"
f
"
{
res_name_3
:
>
3
}
{
chain_id
:
>
1
}
"
f
"
{
residue_index
[
i
]:
>
4
}{
insertion_code
:
>
1
}
"
f
"
{
pos
[
0
]:
>
8.3
f
}{
pos
[
1
]:
>
8.3
f
}{
pos
[
2
]:
>
8.3
f
}
"
f
"
{
occupancy
:
>
6.2
f
}{
b_factor
:
>
6.2
f
}
"
f
"
{
element
:
>
2
}{
charge
:
>
2
}
"
)
atom_line
=
(
f
"
{
record_type
:
<
6
}{
atom_index
:
>
5
}
{
name
:
<
4
}{
alt_loc
:
>
1
}
"
f
"
{
res_name_3
:
>
3
}
{
chain_ids
[
chain_index
[
i
]]:
>
1
}
"
f
"
{
residue_index
[
i
]:
>
4
}{
insertion_code
:
>
1
}
"
f
"
{
pos
[
0
]:
>
8.3
f
}{
pos
[
1
]:
>
8.3
f
}{
pos
[
2
]:
>
8.3
f
}
"
f
"
{
occupancy
:
>
6.2
f
}{
b_factor
:
>
6.2
f
}
"
f
"
{
element
:
>
2
}{
charge
:
>
2
}
"
)
pdb_lines
.
append
(
atom_line
)
atom_index
+=
1
# Close the chain.
chain_end
=
"TER"
chain_termination_line
=
(
f
"
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
{
res_1to3
(
aatype
[
-
1
]):
>
3
}
"
f
"
{
chain_id
:
>
1
}{
residue_index
[
-
1
]:
>
4
}
"
)
pdb_lines
.
append
(
chain_termination_line
)
# Close the final chain.
pdb_lines
.
append
(
_chain_end
(
atom_index
,
res_1to3
(
aatype
[
-
1
]),
chain_ids
[
chain_index
[
-
1
]],
residue_index
[
-
1
]
)
)
pdb_lines
.
append
(
"ENDMDL"
)
pdb_lines
.
append
(
"END"
)
pdb_lines
.
append
(
""
)
return
"
\n
"
.
join
(
pdb_lines
)
# Pad all lines to 80 characters
pdb_lines
=
[
line
.
ljust
(
80
)
for
line
in
pdb_lines
]
return
'
\n
'
.
join
(
pdb_lines
)
+
'
\n
'
# Add terminating newline.
def
ideal_atom_mask
(
prot
:
Protein
)
->
np
.
ndarray
:
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
...
...
@@ -258,24 +323,37 @@ def from_prediction(
features
:
FeatureDict
,
result
:
ModelOutput
,
b_factors
:
Optional
[
np
.
ndarray
]
=
None
,
remove_leading_feature_dimension
:
bool
=
False
,
)
->
Protein
:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values
Returns:
A protein instance.
"""
def
_maybe_remove_leading_dim
(
arr
:
np
.
ndarray
)
->
np
.
ndarray
:
return
arr
[
0
]
if
remove_leading_feature_dimension
else
arr
if
'asym_id'
in
features
:
chain_index
=
_maybe_remove_leading_dim
(
features
[
"asym_id"
])
else
:
chain_index
=
np
.
zeros_like
(
_maybe_remove_leading_dim
(
features
[
"aatype"
])
)
if
b_factors
is
None
:
b_factors
=
np
.
zeros_like
(
result
[
"final_atom_mask"
])
return
Protein
(
aatype
=
features
[
"aatype"
],
aatype
=
_maybe_remove_leading_dim
(
features
[
"aatype"
]
)
,
atom_positions
=
result
[
"final_atom_positions"
],
atom_mask
=
result
[
"final_atom_mask"
],
residue_index
=
features
[
"residue_index"
]
+
1
,
residue_index
=
_maybe_remove_leading_dim
(
features
[
"residue_index"
])
+
1
,
chain_index
=
chain_index
,
b_factors
=
b_factors
,
)
\ No newline at end of file
)
fastfold/data/data_pipeline.py
View file @
b7ee0ff3
...
...
@@ -249,7 +249,7 @@ def run_msa_tool(
max_sto_sequences
:
Optional
[
int
]
=
None
,
)
->
Mapping
[
str
,
Any
]:
"""Runs an MSA tool, checking if output already exists first."""
if
(
msa_format
==
"sto"
and
max_sto_sequences
is
not
None
):
if
(
msa_format
==
"sto"
):
result
=
msa_runner
.
query
(
fasta_path
,
max_sto_sequences
)[
0
]
else
:
result
=
msa_runner
.
query
(
fasta_path
)
...
...
fastfold/utils/import_weights.py
View file @
b7ee0ff3
...
...
@@ -606,3 +606,22 @@ def import_jax_weights_(model, npz_path, version="model_1"):
# Set weights
assign
(
flat
,
data
)
if
is_fused_triangle_multiplication
():
# (NOTE) in multimer v3, alphafold use fused tri, so need change left/right here
for
b
in
model
.
template_embedder
.
template_pair_stack
.
blocks
:
_change_tri_mul_in_left_right
(
b
.
tri_mul_in
)
for
b
in
model
.
extra_msa_stack
.
blocks
:
_change_tri_mul_in_left_right
(
b
.
core
.
tri_mul_in
)
for
b
in
model
.
evoformer
.
blocks
:
_change_tri_mul_in_left_right
(
b
.
core
.
tri_mul_in
)
def
_change_tri_mul_in_left_right
(
module
):
def
_change_para
(
para
):
left_right_para
=
para
.
clone
().
chunk
(
2
,
dim
=
0
)
return
torch
.
cat
((
left_right_para
[
1
],
left_right_para
[
0
]),
dim
=
0
)
with
torch
.
no_grad
():
module
.
linear_p
.
weight
.
copy_
(
_change_para
(
module
.
linear_p
.
weight
))
module
.
linear_p
.
bias
.
copy_
(
_change_para
(
module
.
linear_p
.
bias
))
module
.
linear_g
.
weight
.
copy_
(
_change_para
(
module
.
linear_g
.
weight
))
module
.
linear_g
.
bias
.
copy_
(
_change_para
(
module
.
linear_g
.
bias
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment