OpenDAS / FastFold · Commits · Commit 92e6cf49 (unverified)

Authored Apr 21, 2022 by shenggan; committed via GitHub on Apr 21, 2022.

Merge pull request #18 from hpcaitech/sync_openfold_591d10d

    sync with openfold 591d10d

Parents: 72444d5b, 5f052a0a

Showing 6 changed files with 54 additions and 42 deletions (+54 -42):
fastfold/config.py                 +2  -14
fastfold/data/data_transforms.py  +31  -12
fastfold/model/hub/alphafold.py    +1   -0
fastfold/model/nn/embedders.py     +9  -12
fastfold/utils/feats.py            +9   -2
fastfold/utils/rigid_utils.py      +2   -2
fastfold/config.py (View file @ 92e6cf49)
```diff
 # Copyright 2021 AlQuraishi Laboratory
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 # http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy

 import ml_collections as mlc

@@ -269,6 +255,7 @@ config = mlc.ConfigDict(
             "clamp_prob": 0.9,
             "max_distillation_msa_clusters": 1000,
             "uniform_recycling": True,
             "distillation_prob": 0.75,
         },
         "data_module": {
             "use_small_bfd": False,
@@ -347,6 +334,7 @@ config = mlc.ConfigDict(
             "eps": eps,  # 1e-6,
             "enabled": templates_enabled,
             "embed_angles": embed_template_torsion_angles,
+            "use_unit_vector": False,
         },
         "extra_msa": {
             "extra_msa_embedder": {
```
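The new `use_unit_vector` key is read back out of the config tree through ml_collections' attribute-style access. A minimal sketch of that flow, with the surrounding config structure abbreviated and the other values made up for illustration (only the key names come from this diff):

```python
import ml_collections as mlc

# Abbreviated stand-in for the full FastFold config tree; only the
# "template" keys touched by this commit are shown, with illustrative values.
config = mlc.ConfigDict(
    {"template": {"use_unit_vector": False, "inf": 1e5, "eps": 1e-6}}
)

# Attribute-style access, as used in alphafold.py below.
print(config.template.use_unit_vector)  # False
```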
fastfold/data/data_transforms.py (View file @ 92e6cf49)
```diff
@@ -50,7 +50,7 @@ def cast_to_64bit_ints(protein):
 def make_one_hot(x, num_classes):
-    x_one_hot = torch.zeros(*x.shape, num_classes)
+    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
     x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
     return x_one_hot
```
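Nearly every hunk in this file makes the same fix: tensors that were created with the default (CPU) placement now inherit the device of an input tensor. A minimal sketch of the failure mode the change avoids, assuming a CUDA device is present:

```python
import torch

x = torch.arange(4)
if torch.cuda.is_available():
    x = x.cuda()

num_classes = 5

# Before the fix: the output is allocated on the CPU, so the in-place
# scatter_ raises a device-mismatch RuntimeError whenever x is on the GPU.
# x_one_hot = torch.zeros(*x.shape, num_classes)
# x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)  # RuntimeError on CUDA inputs

# After the fix: allocate on x.device, which works on CPU and GPU alike.
x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
print(x_one_hot.device)
```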
```diff
@@ -92,9 +92,9 @@ def fix_templates_aatype(protein):
     )
     # Map hhsearch-aatype to our aatype.
     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
-    new_order = torch.tensor(
-        new_order_list, dtype=torch.int64
-    ).expand(num_templates, -1)
+    new_order = torch.tensor(
+        new_order_list,
+        dtype=torch.int64,
+        device=protein["aatype"].device,
+    ).expand(num_templates, -1)
     protein["template_aatype"] = torch.gather(
         new_order, 1, index=protein["template_aatype"]
     )
```
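The `torch.gather` call remaps template residue indices from HHsearch's alphabet ordering into the model's ordering. A toy sketch of that pattern with a hypothetical 4-letter alphabet (the real table is `rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE` in residue_constants):

```python
import torch

# Hypothetical permutation: index i in the "hhblits" order maps to MAP[i]
# in "our" order.
MAP = [2, 0, 3, 1]
num_templates = 2
template_aatype = torch.tensor([[0, 1, 2], [3, 3, 0]])  # hhblits-order indices

new_order = torch.tensor(
    MAP, dtype=torch.int64, device=template_aatype.device
).expand(num_templates, -1)

# Gather along dim 1: out[t, i] = new_order[t, template_aatype[t, i]]
remapped = torch.gather(new_order, 1, index=template_aatype)
print(remapped)  # tensor([[2, 0, 3], [1, 1, 2]])
```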
```diff
@@ -106,7 +106,8 @@ def correct_msa_restypes(protein):
     """Correct MSA restype to have the same order as rc."""
     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
-    new_order = torch.tensor(
-        [new_order_list] * protein["msa"].shape[1], dtype=protein["msa"].dtype
-    ).transpose(0, 1)
+    new_order = torch.tensor(
+        [new_order_list] * protein["msa"].shape[1],
+        dtype=protein["msa"].dtype,
+        device=protein["msa"].device,
+    ).transpose(0, 1)
     protein["msa"] = torch.gather(new_order, 0, protein["msa"])
```
```diff
@@ -187,7 +188,10 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
     if seed is not None:
         g.manual_seed(seed)
     shuffled = torch.randperm(num_seq - 1, generator=g) + 1
-    index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
+    index_order = torch.cat(
+        (torch.tensor([0], device=shuffled.device), shuffled),
+        dim=0,
+    )
     num_sel = min(max_seq, num_seq)
     sel_seq, not_sel_seq = torch.split(index_order, [num_sel, num_seq - num_sel])
```
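The surrounding function keeps the query sequence (row 0) fixed and shuffles only the remaining MSA rows, reproducibly when a seed is supplied. A runnable sketch of that pattern with illustrative sizes:

```python
import torch

num_seq, max_seq = 6, 4

g = torch.Generator()
g.manual_seed(42)  # same seed -> same shuffle, as in sample_msa

# Shuffle rows 1..num_seq-1; row 0 (the query) always stays first.
shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat(
    (torch.tensor([0], device=shuffled.device), shuffled), dim=0
)

num_sel = min(max_seq, num_seq)
sel_seq, not_sel_seq = torch.split(index_order, [num_sel, num_seq - num_sel])
print(sel_seq)      # rows kept as the sampled MSA
print(not_sel_seq)  # rows spilled into the "extra" MSA
```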
```diff
@@ -242,7 +246,7 @@ def delete_extra_msa(protein):
 def block_delete_msa(protein, config):
     num_seq = protein["msa"].shape[0]
     block_num_seq = torch.floor(
-        torch.tensor(num_seq, dtype=torch.float32)
+        torch.tensor(num_seq, dtype=torch.float32, device=protein["msa"].device)
         * config.msa_fraction_per_block
     ).to(torch.int32)
```
```diff
@@ -275,7 +279,11 @@ def block_delete_msa(protein, config):
 @curry1
 def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
     weights = torch.cat(
-        [torch.ones(21), gap_agreement_weight * torch.ones(1), torch.zeros(1)],
+        [
+            torch.ones(21, device=protein["msa"].device),
+            gap_agreement_weight * torch.ones(1, device=protein["msa"].device),
+            torch.zeros(1, device=protein["msa"].device),
+        ],
         0,
     )
```
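For context, the 23 entries of `weights` line up with the one-hot MSA channels (this reading follows the AlphaFold data pipeline, not anything stated in the diff itself): 21 residue-type channels count fully toward cluster agreement, the gap channel counts with `gap_agreement_weight` (0 by default), and the mask channel is ignored. A standalone sketch:

```python
import torch

gap_agreement_weight = 0.0
device = torch.device("cpu")  # stands in for protein["msa"].device

weights = torch.cat(
    [
        torch.ones(21, device=device),                        # residue types
        gap_agreement_weight * torch.ones(1, device=device),  # gap channel
        torch.zeros(1, device=device),                        # mask channel
    ],
    0,
)
print(weights.shape)  # torch.Size([23])
```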
```diff
@@ -324,7 +332,10 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
     )
     segment_ids = segment_ids.expand(data.shape)
     shape = [num_segments] + list(data.shape[1:])
-    tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float())
+    tensor = (
+        torch.zeros(*shape, device=segment_ids.device)
+        .scatter_add_(0, segment_ids, data.float())
+    )
     tensor = tensor.type(data.dtype)
     return tensor
```
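`scatter_add_` is what makes this an unsorted segment sum: each row of `data` is accumulated into the output row named by its segment id. A small worked example with illustrative data (not from the repo):

```python
import torch

data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
segment_ids = torch.tensor([0, 1, 0])  # rows 0 and 2 share segment 0
num_segments = 2

# Broadcast the ids to data's shape, then accumulate along dim 0:
# out[ids[i, j], j] += data[i, j]
ids = segment_ids.unsqueeze(-1).expand(data.shape)
out = torch.zeros(
    num_segments, data.shape[1], device=ids.device
).scatter_add_(0, ids, data)
print(out)  # tensor([[6., 8.], [3., 4.]])
```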
```diff
@@ -401,7 +412,7 @@ def make_pseudo_beta(protein, prefix=""):
 @curry1
 def add_constant_field(protein, key, value):
-    protein[key] = torch.tensor(value)
+    protein[key] = torch.tensor(value, device=protein["msa"].device)
     return protein
```
```diff
@@ -431,7 +442,11 @@ def make_hhblits_profile(protein):
 def make_masked_msa(protein, config, replace_fraction):
     """Create data for BERT on raw MSA."""
     # Add a random amino acid uniformly.
-    random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)
+    random_aa = torch.tensor(
+        [0.05] * 20 + [0.0, 0.0],
+        dtype=torch.float32,
+        device=protein["aatype"].device,
+    )
     categorical_probs = (
         config.uniform_prob * random_aa
```
```diff
@@ -644,7 +659,11 @@ def make_atom14_masks(protein):
 def make_atom14_masks_np(batch):
-    batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
+    batch = tree_map(
+        lambda n: torch.tensor(n, device=batch["aatype"].device),
+        batch,
+        np.ndarray,
+    )
     out = make_atom14_masks(batch)
     out = tensor_tree_map(lambda t: np.array(t), out)
     return out
```
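This wrapper converts a NumPy batch to torch tensors, runs the torch-only transform, and converts the result back. A sketch of that round trip with plain dict comprehensions standing in for `tree_map` / `tensor_tree_map`, and hypothetical batch contents:

```python
import numpy as np
import torch

batch = {"aatype": np.zeros((8,), dtype=np.int64)}

# numpy -> torch (tree_map walks nested structures; a flat dict suffices here)
tensors = {k: torch.tensor(v) for k, v in batch.items()}

# ... the torch-only transform (make_atom14_masks) would run here ...

# torch -> numpy
arrays = {k: np.array(t) for k, t in tensors.items()}
print(arrays["aatype"].dtype)  # int64
```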
fastfold/model/hub/alphafold.py (View file @ 92e6cf49)
```diff
@@ -131,6 +131,7 @@ class AlphaFold(nn.Module):
                 # [*, S_t, N, N, C_t]
                 t = build_template_pair_feat(
                     single_template_feats,
+                    use_unit_vector=self.config.template.use_unit_vector,
                     inf=self.config.template.inf,
                     eps=self.config.template.eps,
                     **self.config.template.distogram,
```
fastfold/model/nn/embedders.py (View file @ 92e6cf49)
```diff
@@ -165,8 +165,6 @@ class RecyclingEmbedder(nn.Module):
         self.no_bins = no_bins
         self.inf = inf

-        self.bins = None
-
         self.linear = Linear(self.no_bins, self.c_z)
         self.layer_norm_m = LayerNorm(self.c_m)
         self.layer_norm_z = LayerNorm(self.c_z)
@@ -191,15 +189,14 @@ class RecyclingEmbedder(nn.Module):
             z:
                 [*, N_res, N_res, C_z] pair embedding update
         """
-        if self.bins is None:
-            self.bins = torch.linspace(
-                self.min_bin,
-                self.max_bin,
-                self.no_bins,
-                dtype=x.dtype,
-                device=x.device,
-                requires_grad=False,
-            )
+        bins = torch.linspace(
+            self.min_bin,
+            self.max_bin,
+            self.no_bins,
+            dtype=x.dtype,
+            device=x.device,
+            requires_grad=False,
+        )

         # [*, N, C_m]
         m_update = self.layer_norm_m(m)
@@ -207,7 +204,7 @@ class RecyclingEmbedder(nn.Module):
         # This squared method might become problematic in FP16 mode.
         # I'm using it because my homegrown method had a stubborn discrepancy I
         # couldn't find in time.
-        squared_bins = self.bins ** 2
+        squared_bins = bins ** 2
         upper = torch.cat(
             [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
         )
```
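Recomputing `bins` on every forward pass, instead of caching it on the module, guarantees the tensor always matches the current input's dtype and device; a cached version keeps whatever dtype and device the first call happened to use. A standalone sketch of the recomputed variant (bin bounds are illustrative defaults, not read from the FastFold config):

```python
import torch

def recycle_bins(x, min_bin=3.25, max_bin=20.75, no_bins=15):
    # Built fresh from x on every call, so runs that switch precision or
    # device between calls can never see a stale cached tensor.
    return torch.linspace(
        min_bin, max_bin, no_bins,
        dtype=x.dtype, device=x.device, requires_grad=False,
    )

x32 = torch.randn(4, 3)
x64 = torch.randn(4, 3, dtype=torch.float64)
print(recycle_bins(x32).dtype, recycle_bins(x64).dtype)  # float32 float64
```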
fastfold/utils/feats.py (View file @ 92e6cf49)
```diff
@@ -90,7 +90,10 @@ def build_template_angle_feat(template_feats):
 def build_template_pair_feat(
-    batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8
+    batch,
+    min_bin,
+    max_bin,
+    no_bins,
+    use_unit_vector=False,
+    eps=1e-20,
+    inf=1e8,
 ):
     template_mask = batch["template_pseudo_beta_mask"]
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
```
```diff
@@ -101,7 +104,7 @@ def build_template_pair_feat(
         (tpb[..., None, :] - tpb[..., None, :, :]) ** 2,
         dim=-1,
         keepdim=True,
     )
     lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
-    upper = torch.cat([lower[:-1], lower.new_tensor([inf])], dim=-1)
+    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
     dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)

     to_concat = [dgram, template_mask_2d[..., None]]
```
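The `lower[:-1]` to `lower[1:]` change fixes an off-by-one in the distogram bin edges: each bin's upper edge must be the *next* lower edge, not the previous one. With the old edges, every bin except the last was an empty interval, so `(dgram > lower) * (dgram < upper)` could never fire for it. A worked check with toy numbers:

```python
import torch

lower = torch.linspace(1.0, 4.0, 4) ** 2  # tensor([ 1.,  4.,  9., 16.])
inf = 1e8

upper_old = torch.cat([lower[:-1], lower.new_tensor([inf])], dim=-1)
upper_new = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)

d = torch.tensor(5.0)  # a squared distance that belongs in bin 1: [4, 9)
print((d > lower) * (d < upper_old))  # old: tensor([False, False, False, False])
print((d > lower) * (d < upper_new))  # new: tensor([False,  True, False, False])
```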
```diff
@@ -143,6 +146,10 @@ def build_template_pair_feat(
     inv_distance_scalar = inv_distance_scalar * template_mask_2d
     unit_vector = rigid_vec * inv_distance_scalar[..., None]
+
+    if not use_unit_vector:
+        unit_vector = unit_vector * 0.
+
     to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
     to_concat.append(template_mask_2d[..., None])
```
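When `use_unit_vector` is off, the unit-vector channels are zeroed rather than dropped, so the template pair feature keeps the same channel count either way and downstream layer shapes stay fixed. A shape-level sketch with made-up dimensions:

```python
import torch

use_unit_vector = False
unit_vector = torch.randn(2, 4, 4, 3)  # [*, N_res, N_res, 3], illustrative

if not use_unit_vector:
    unit_vector = unit_vector * 0.  # zeroed, but still concatenated

to_concat = list(torch.unbind(unit_vector[..., None, :], dim=-1))
print(len(to_concat), to_concat[0].shape)  # 3 torch.Size([2, 4, 4, 1])
```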
fastfold/utils/rigid_utils.py (View file @ 92e6cf49)
```diff
@@ -1352,8 +1352,8 @@ class Rigid:
         c2_rots[..., 0, 0] = cos_c2
         c2_rots[..., 0, 2] = sin_c2
         c2_rots[..., 1, 1] = 1
-        c1_rots[..., 2, 0] = -1 * sin_c2
-        c1_rots[..., 2, 2] = cos_c2
+        c2_rots[..., 2, 0] = -1 * sin_c2
+        c2_rots[..., 2, 2] = cos_c2
         c_rots = rot_matmul(c2_rots, c1_rots)
         n_xyz = rot_vec_mul(c_rots, n_xyz)
```
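This hunk corrects a typo: the bottom row of the second rotation was being written into `c1_rots`, which both clobbered `c1_rots` and left `c2_rots` with an all-zero third row (a singular matrix). A quick check that the corrected `c2_rots` is a proper rotation (the angle here is illustrative):

```python
import math
import torch

theta = 0.3  # illustrative torsion angle
cos_c2, sin_c2 = math.cos(theta), math.sin(theta)

c2_rots = torch.zeros(3, 3)
c2_rots[0, 0] = cos_c2
c2_rots[0, 2] = sin_c2
c2_rots[1, 1] = 1
c2_rots[2, 0] = -1 * sin_c2  # the two rows previously assigned to c1_rots
c2_rots[2, 2] = cos_c2

print(torch.allclose(c2_rots @ c2_rots.T, torch.eye(3), atol=1e-6))  # True
print(torch.det(c2_rots))  # ~1.0, a proper rotation
```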