GitLab · OpenDAS / FastFold · Commits

Commit e5d72d41, authored Dec 07, 2022 by zhuww

    infer longer sequences using memory optimization

Parent: 52919c63

Showing 10 changed files with 231 additions and 69 deletions:
    fastfold/distributed/comm.py           +61   -4
    fastfold/model/fastnn/embedders.py      +2   -0
    fastfold/model/fastnn/evoformer.py      +9   -4
    fastfold/model/fastnn/msa.py           +13   -5
    fastfold/model/fastnn/ops.py            +5   -0
    fastfold/model/fastnn/template.py       +6   -3
    fastfold/model/hub/alphafold.py        +51  -24
    fastfold/model/nn/structure_module.py  +66  -13
    inference.py                           +15  -15
    inference.sh                            +3   -1
fastfold/distributed/comm.py (view file @ e5d72d41)

```diff
@@ -27,14 +27,19 @@ def _reduce(tensor: Tensor) -> Tensor:
     return tensor


-def _split(tensor: Tensor, dim: int = -1) -> Tensor:
+def _split(tensor: Tensor, dim: int = -1, drop_unused=False) -> Tensor:
     if gpc.get_world_size(ParallelMode.TENSOR) == 1:
         return tensor

     split_size = divide(tensor.shape[dim], gpc.get_world_size(ParallelMode.TENSOR))
     tensor_list = torch.split(tensor, split_size, dim=dim)

-    output = tensor_list[gpc.get_local_rank(ParallelMode.TENSOR)].contiguous()
+    rank = gpc.get_local_rank(ParallelMode.TENSOR)
+    if not drop_unused:
+        output = tensor_list[rank].contiguous()
+    else:
+        output = tensor_list[rank].contiguous().clone()

     return output
```
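The point of the new `drop_unused` path is reference lifetime: a dim-0 slice of a contiguous tensor is itself contiguous, so `.contiguous()` returns a view that keeps the whole (padded, gathered) buffer alive, while `.clone()` copies only the local shard so the rest can be freed. A minimal sketch of the difference (plain PyTorch, `untyped_storage()` assumes PyTorch ≥ 2.0; variable names are mine, not from the repo):

```python
import torch

# Stand-in for the full padded tensor a rank holds before _split
# picks out its shard.
full = torch.randn(4, 1024, 1024)            # ~16 MiB of float32

view = full.chunk(4, dim=0)[0].contiguous()  # dim-0 chunks are already
                                             # contiguous, so this is a view
copy = full.chunk(4, dim=0)[0].clone()       # owns only its own quarter

print(view.untyped_storage().nbytes())       # ~16 MiB: pins the whole buffer
print(copy.untyped_storage().nbytes())       # ~4 MiB: the rest is freeable
```

The `scatter(..., drop_unused=True)` call sites later in this commit presumably forward this flag down to `_split`, so each rank drops the 3/4 of the buffer it does not own.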
```diff
@@ -65,6 +70,55 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
     return output


+def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
+    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
+        return tensor
+
+    if dim == 1 and list(tensor.shape)[0] == 1:
+        output_shape = list(tensor.shape)
+        output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
+        output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
+        world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
+        tensor_list = []
+        for t in world_list:
+            tensor_list.extend(t.chunk(chunks, dim=1))
+        chunk_tensor = tensor.chunk(chunks, dim=1)
+        for i in range(chunks):
+            _chunk_list = [tensor_list[j * 4 + i] for j in range(4)]
+            _chunk_tensor = chunk_tensor[i]
+            dist.all_gather(list(_chunk_list),
+                            _chunk_tensor,
+                            group=gpc.get_group(ParallelMode.TENSOR),
+                            async_op=False)
+            torch.cuda.empty_cache()
+    else:
+        output_shape = list(tensor.shape)
+        output_shape[0] *= gpc.get_world_size(ParallelMode.TENSOR)
+        output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
+        world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=0)
+        tensor_list = []
+        for t in world_list:
+            tensor_list.extend(t.chunk(chunks, dim=0))
+        chunk_tensor = tensor.chunk(chunks, dim=0)
+        for i in range(chunks):
+            _chunk_list = [tensor_list[j * 4 + i] for j in range(4)]
+            _chunk_tensor = chunk_tensor[i]
+            dist.all_gather(list(_chunk_list),
+                            _chunk_tensor,
+                            group=gpc.get_group(ParallelMode.TENSOR),
+                            async_op=False)
+            torch.cuda.empty_cache()
+
+    return output
+
+
 def copy(input: Tensor) -> Tensor:
     if torch.is_grad_enabled() and input.requires_grad:
         input = Copy.apply(input)
```
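`_chunk_gather` preallocates the full output, views it as `world_size × chunks` slots, and runs `chunks` successive `all_gather` calls directly into those views, so the transient communication buffers are a fraction of the full tensor. Note that `tensor_list[j * 4 + i]` and `range(4)` hardcode the configuration this commit targets: the slot index is really `rank * chunks + chunk`, which coincides with `j * 4 + i` only when `chunks == 4`, and `range(4)` assumes 4 tensor-parallel ranks. A generalized sketch under those assumptions (helper name and the single-process demo are mine, not from the repo):

```python
import torch
import torch.distributed as dist

def chunked_all_gather(tensor: torch.Tensor, chunks: int, group=None) -> torch.Tensor:
    """All-gather along dim 0 in `chunks` pieces, writing straight into a
    preallocated output so no full-size staging buffer is ever live."""
    world_size = dist.get_world_size(group)
    out_shape = list(tensor.shape)
    out_shape[0] *= world_size
    output = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)

    # Slot j * chunks + i receives rank j's i-th chunk. Dim-0 chunks of a
    # contiguous tensor are contiguous views, so collectives can target them.
    slots = []
    for per_rank in output.chunk(world_size, dim=0):
        slots.extend(per_rank.chunk(chunks, dim=0))
    local = tensor.chunk(chunks, dim=0)

    for i in range(chunks):
        dest = [slots[j * chunks + i] for j in range(world_size)]
        dist.all_gather(dest, local[i].contiguous(), group=group)
    return output

if __name__ == "__main__":
    # Trivial single-process check: with world_size == 1 the gather is identity.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29512",
                            rank=0, world_size=1)
    x = torch.randn(8, 16)
    assert torch.equal(chunked_all_gather(x, chunks=4), x)
```

The `dim == 1 and shape[0] == 1` branch in the commit exists for the same contiguity reason: when the leading dimension is 1, chunks along dim 1 are still contiguous views of the output buffer.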
```diff
@@ -122,11 +176,14 @@ class Reduce(torch.autograd.Function):
         return grad_output


-def gather(input: Tensor, dim: int = -1) -> Tensor:
+def gather(input: Tensor, dim: int = -1, chunks: int = None) -> Tensor:
     if torch.is_grad_enabled() and input.requires_grad:
         input = Gather.apply(input, dim)
     else:
-        input = _gather(input, dim=dim)
+        if chunks is None:
+            input = _gather(input, dim=dim)
+        else:
+            input = _chunk_gather(input, dim=dim, chunks=chunks)
     return input
```
fastfold/model/fastnn/embedders.py (view file @ e5d72d41)

```diff
@@ -224,6 +224,7 @@ class TemplateEmbedder(nn.Module):
             ).to(t.device)
             del tt, single_template_feats
+            torch.cuda.empty_cache()

         template_embeds = dict_multimap(
             partial(torch.cat, dim=templ_dim),
@@ -245,6 +246,7 @@ class TemplateEmbedder(nn.Module):
             template_mask=batch["template_mask"].to(dtype=z.dtype),
             chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
         )
+        torch.cuda.empty_cache()

         ret = {}
         ret["template_pair_embedding"] = z
```
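Most of the smaller hunks in this commit insert `torch.cuda.empty_cache()` right after large intermediates are dropped. `del` only returns a tensor's blocks to PyTorch's caching allocator; `empty_cache()` hands the cached blocks back to the driver, which matters here because fragmented cached blocks can make a later large allocation fail even when total free memory would suffice. A quick way to observe the difference (standalone sketch, needs a CUDA device):

```python
import torch

assert torch.cuda.is_available()
x = torch.randn(64, 1024, 1024, device="cuda")  # ~256 MiB
del x
print(torch.cuda.memory_reserved())  # still ~256 MiB: blocks stay cached
torch.cuda.empty_cache()
print(torch.cuda.memory_reserved())  # ~0: cache returned to the driver
```

The flip side is that subsequent allocations must go back through cudaMalloc, so this trades speed for headroom, an acceptable trade for long-sequence inference.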
fastfold/model/fastnn/evoformer.py (view file @ e5d72d41)

```diff
@@ -105,9 +105,13 @@ class Evoformer(nn.Module):
                 m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size))
                 z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
+                torch.cuda.empty_cache()

-            m[0] = scatter(m[0], dim=1)
-            z[0] = scatter(z[0], dim=1)
+            m[0] = scatter(m[0], dim=1, drop_unused=True)
+            torch.cuda.empty_cache()
+            z[0] = scatter(z[0], dim=1, drop_unused=True)
+            torch.cuda.empty_cache()

         msa_mask = msa_mask.unsqueeze(0)
         pair_mask = pair_mask.unsqueeze(0)
```
```diff
@@ -137,9 +141,10 @@ class Evoformer(nn.Module):
         if self.last_block:
             m[0] = m[0].squeeze(0)
             z[0] = z[0].squeeze(0)
+            torch.cuda.empty_cache()

-            m[0] = gather(m[0], dim=0)
-            z[0] = gather(z[0], dim=0)
+            m[0] = gather(m[0], dim=0, chunks=4)
+            z[0] = gather(z[0], dim=0, chunks=4)

             m[0] = m[0][:, :-padding_size, :]
             z[0] = z[0][:-padding_size, :-padding_size, :]
```
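Before scattering, the sharded dimension is padded so it divides evenly across the DAP (dynamic axial parallelism) ranks; after the final gather the padding is sliced off again. The formula `(int(n / dap_size) + 1) * dap_size - n` always pads at least one row, even when `n` is already a multiple, which also keeps the later `[:-padding_size]` slices well-defined. A minimal sketch of the pattern (helper name is mine; the same formula reappears in msa.py and template.py below):

```python
import torch

def pad_to_multiple(x: torch.Tensor, dap_size: int, dim: int):
    """Pad `dim` of x up to the next multiple of dap_size, as the
    Evoformer does before scattering across tensor-parallel ranks."""
    n = x.size(dim)
    padding_size = (int(n / dap_size) + 1) * dap_size - n
    # F.pad takes pairs from the last dim backwards.
    pad = [0, 0] * (x.dim() - dim - 1) + [0, padding_size]
    return torch.nn.functional.pad(x, pad), padding_size

x = torch.randn(2, 7, 5)
y, p = pad_to_multiple(x, dap_size=4, dim=1)
assert y.shape[1] == 8 and p == 1
```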
fastfold/model/fastnn/msa.py (view file @ e5d72d41)

```diff
@@ -217,7 +217,7 @@ class ExtraMSABlock(nn.Module):
         seq_len = pair_mask.size(-1)
         seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
         seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len

         if self.first_block:
             m = m.unsqueeze(0)
             z = z.unsqueeze(0)
@@ -225,6 +225,7 @@ class ExtraMSABlock(nn.Module):
             m = torch.nn.functional.pad(m, (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size))
             z = torch.nn.functional.pad(z, (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size))
@@ -286,6 +287,8 @@ class ExtraMSABlock(nn.Module):
         seq_len = pair_mask.size(-1)
         seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
         seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
+        torch.cuda.empty_cache()

         if self.first_block:
             m[0] = m[0].unsqueeze(0)
@@ -294,12 +297,16 @@ class ExtraMSABlock(nn.Module):
             m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size))
+            torch.cuda.empty_cache()
             z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size))
+            torch.cuda.empty_cache()

-            m[0] = scatter(m[0], dim=1) if not self.is_multimer else scatter(m[0], dim=2)
-            z[0] = scatter(z[0], dim=1)
+            m[0] = scatter(m[0], dim=1, drop_unused=True) if not self.is_multimer else scatter(m[0], dim=2)
+            torch.cuda.empty_cache()
+            z[0] = scatter(z[0], dim=1, drop_unused=True)
+            torch.cuda.empty_cache()

         msa_mask = msa_mask.unsqueeze(0)
         pair_mask = pair_mask.unsqueeze(0)
@@ -332,8 +339,9 @@ class ExtraMSABlock(nn.Module):
         if self.last_block:
-            m[0] = gather(m[0], dim=1) if not self.is_multimer else gather(m[0], dim=2)
-            z[0] = gather(z[0], dim=1)
+            m[0] = gather(m[0], dim=1, chunks=4) if not self.is_multimer else gather(m[0], dim=2)
+            torch.cuda.empty_cache()
+            z[0] = gather(z[0], dim=1, chunks=4)

             m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
             z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]
```
fastfold/model/fastnn/ops.py (view file @ e5d72d41)

```diff
@@ -207,7 +207,9 @@ class OutProductMean(nn.Module):
         norm = torch.einsum('bsid,bsjd->bijd', M_mask_col, M_mask) + 1e-3

         right_act_all = gather_async_opp(right_act_all, work, dim=2)
+        torch.cuda.empty_cache()
         right_act_all = M_mask * right_act_all
+        torch.cuda.empty_cache()

         para_dim = left_act.shape[2]
         chunk_size = CHUNK_SIZE
@@ -1272,11 +1274,14 @@ class InputEmbedder(nn.Module):
         reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))

         pair_emb = d[..., None] - reshaped_bins
+        del d, reshaped_bins
+        torch.cuda.empty_cache()
         pair_emb = torch.argmin(torch.abs(pair_emb), dim=-1)
         pair_emb = nn.functional.one_hot(pair_emb, num_classes=len(boundaries)).float().type(ri.dtype)
         pair_emb = self.linear_relpos(pair_emb)
         pair_emb += tf_emb_i[..., None, :]
         pair_emb += tf_emb_j[..., None, :, :]
+        torch.cuda.empty_cache()

         # [*, N_clust, N_res, c_m]
         n_clust = msa.shape[-3]
```
fastfold/model/fastnn/template.py (view file @ e5d72d41)

```diff
@@ -282,9 +282,11 @@ class TemplatePairBlock(nn.Module):
         seq_length = mask.size(-1)
         padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
+        torch.cuda.empty_cache()

         if self.first_block:
             z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
-            z[0] = scatter(z[0], dim=1)
+            z[0] = scatter(z[0], dim=1, drop_unused=True)
+            torch.cuda.empty_cache()

         mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
@@ -294,8 +296,8 @@ class TemplatePairBlock(nn.Module):
             single = z[0][i].unsqueeze(-4).to(mask.device)
             single_mask = mask[i].unsqueeze(-3)
-            single_mask_row = scatter(single_mask, dim=1)
-            single_mask_col = scatter(single_mask, dim=2)
+            single_mask_row = scatter(single_mask, dim=1, drop_unused=True)
+            single_mask_col = scatter(single_mask, dim=2, drop_unused=True)
             single = self.TriangleAttentionStartingNode(single, single_mask_row)
             single = row_to_col(single)
@@ -307,6 +309,7 @@ class TemplatePairBlock(nn.Module):
             single = self.PairTransition(single)
             single = col_to_row(single)
             z[0][i] = single.to(z[0].device)

             # z = torch.cat(single_templates, dim=-4)
         if self.last_block:
```
fastfold/model/hub/alphafold.py (view file @ e5d72d41)

```diff
@@ -169,7 +169,8 @@ class AlphaFold(nn.Module):
         return ret

-    def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
+    def iteration(self, feats, prevs, _recycle=True, no_iter=0):
+        torch.cuda.empty_cache()
         # Primary output dictionary
         outputs = {}
@@ -203,7 +204,10 @@ class AlphaFold(nn.Module):
             if not self.globals.is_multimer
             else self.input_embedder(feats)
         )
+        torch.cuda.empty_cache()
+
+        m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])

         # Initialize the recycling embeddings, if needs be
         if None in [m_1_prev, z_prev, x_prev]:
             # [*, N, C_m]
@@ -251,6 +255,8 @@ class AlphaFold(nn.Module):
         # Possibly prevents memory fragmentation
         del m_1_prev, z_prev, x_prev
+        torch.cuda.empty_cache()

         # Embed the templates + merge with MSA/pair embeddings
         if self.config.template.enabled:
@@ -330,6 +336,8 @@ class AlphaFold(nn.Module):
             # [*, S_e, N, C_e]
             extra_msa_feat = extra_msa_fn(feats)
             extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
+            torch.cuda.empty_cache()

             # [*, N, N, C_z]
             if not self.globals.inplace:
@@ -353,6 +361,8 @@ class AlphaFold(nn.Module):
                     _mask_trans=self.config._mask_trans,
                 )[0]
             del extra_msa_feat, extra_msa_fn
+            torch.cuda.empty_cache()

         # Run MSA + pair embeddings through the trunk of the network
         # m: [*, S, N, C_m]
@@ -380,36 +390,49 @@ class AlphaFold(nn.Module):
         )
         m = m[0]
         z = z[0]
+        torch.cuda.empty_cache()

-        outputs["msa"] = m[..., :n_seq, :, :]
-        outputs["pair"] = z
-        outputs["single"] = s
+        if no_iter == 3:
+            outputs["msa"] = m[..., :n_seq, :, :]
+            outputs["pair"] = z
+            outputs["single"] = s

         # Predict 3D structure
-        outputs["sm"] = self.structure_module(
+        z = [z]
+        outputs_sm = self.structure_module(
             s,
             z,
             feats["aatype"],
             mask=feats["seq_mask"].to(dtype=s.dtype),
         )
-        outputs["final_atom_positions"] = atom14_to_atom37(
-            outputs["sm"]["positions"][-1], feats
-        )
-        outputs["final_atom_mask"] = feats["atom37_atom_exists"]
-        outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
-
-        # Save embeddings for use during the next recycling iteration
-        # [*, N, C_m]
-        m_1_prev = m[..., 0, :, :]
-
-        # [*, N, N, C_z]
-        z_prev = z
-
-        # [*, N, 3]
-        x_prev = outputs["final_atom_positions"]
-
-        return outputs, m_1_prev, z_prev, x_prev
+        torch.cuda.empty_cache()
+        if no_iter == 3:
+            m_1_prev, z_prev, x_prev = None, None, None
+            outputs["sm"] = outputs_sm
+            outputs["final_atom_positions"] = atom14_to_atom37(
+                outputs["sm"]["positions"][-1], feats
+            )
+            outputs["final_atom_mask"] = feats["atom37_atom_exists"]
+            outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
+        else:
+            # Save embeddings for use during the next recycling iteration
+            # [*, N, N, C_z]
+            z_prev = z
+            # [*, N, C_m]
+            m_1_prev = m[..., 0, :, :]
+            # [*, N, 3]
+            x_prev = outputs["final_atom_positions"]
+        if no_iter != 3:
+            return None, m_1_prev, z_prev, x_prev
+        else:
+            return outputs, m_1_prev, z_prev, x_prev

     def _disable_activation_checkpointing(self):
         self.template_embedder.template_pair_stack.blocks_per_ckpt = None
@@ -482,6 +505,7 @@ class AlphaFold(nn.Module):
         """
         # Initialize recycling embeddings
         m_1_prev, z_prev, x_prev = None, None, None
+        prevs = [m_1_prev, z_prev, x_prev]

         # Disable activation checkpointing for the first few recycling iters
         is_grad_enabled = torch.is_grad_enabled()
@@ -506,11 +530,14 @@ class AlphaFold(nn.Module):
                 # Run the next iteration of the model
                 outputs, m_1_prev, z_prev, x_prev = self.iteration(
                     feats,
-                    m_1_prev,
-                    z_prev,
-                    x_prev,
-                    _recycle=(num_iters > 1)
+                    prevs,
+                    _recycle=(num_iters > 1),
+                    no_iter=cycle_no
                 )
+                if cycle_no != 3:
+                    prevs = [m_1_prev, z_prev, x_prev]
+                del m_1_prev, z_prev, x_prev

         # Run auxiliary heads
         outputs.update(self.aux_heads(outputs))
```
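The `iteration` signature change is itself a memory optimization: the recycling tensors are handed over inside a mutable `prevs` list and popped out by the callee, so the caller stops holding references and each buffer can be freed as soon as `iteration` is done with it. With the old signature, the caller's local names kept all three alive for the whole iteration. Note also that `no_iter == 3` / `cycle_no != 3` hardcode AlphaFold's default of 4 recycling iterations, so full outputs are only assembled on the last pass. A minimal sketch of the hand-off pattern, outside any model code (names are mine):

```python
import torch

def iteration(prevs: list) -> list:
    # Pop everything out of the shared list; after this the caller's
    # `prevs` is empty and this frame holds the only references.
    # (pop() drains back-to-front, so reversed() restores the order.)
    m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
    # ... consume the recycled embeddings ...
    del m_1_prev, z_prev, x_prev  # old buffers die *before* the next
                                  # big allocation, not after return
    return [torch.randn(4), torch.randn(4), torch.randn(4)]

prevs = [torch.randn(4), torch.randn(4), torch.randn(4)]
for cycle_no in range(4):
    prevs = iteration(prevs)
```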
fastfold/model/nn/structure_module.py (view file @ e5d72d41)

```diff
@@ -389,6 +389,8 @@ class InvariantPointAttention(nn.Module):
         )
         a *= math.sqrt(1.0 / (3 * self.c_hidden))
         a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))
+        torch.cuda.empty_cache()

         if self.is_multimer:
             # [*, N_res, N_res, H, P_q, 3]
@@ -396,11 +398,27 @@ class InvariantPointAttention(nn.Module):
             # [*, N_res, N_res, H, P_q]
             pt_att = sum([c**2 for c in pt_att])
         else:
-            # [*, N_res, N_res, H, P_q, 3]
-            pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
-            pt_att = pt_att**2
-            # [*, N_res, N_res, H, P_q]
-            pt_att = sum(torch.unbind(pt_att, dim=-1))
+            # # [*, N_res, N_res, H, P_q, 3]
+            # pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
+            # pt_att = pt_att**2
+            # # [*, N_res, N_res, H, P_q]
+            # pt_att = sum(torch.unbind(pt_att, dim=-1))
+            _chunks = 10
+            _ks = k_pts.unsqueeze(-5)
+            _qlist = torch.chunk(q_pts.unsqueeze(-4), chunks=_chunks, dim=0)
+            pt_att = None
+            for _i in range(_chunks):
+                _pt = _qlist[_i] - _ks
+                _pt = _pt**2
+                _pt = sum(torch.unbind(_pt, dim=-1))
+                if _i == 0:
+                    pt_att = _pt
+                else:
+                    pt_att = torch.cat([pt_att, _pt], dim=0)
+            torch.cuda.empty_cache()

         head_weights = self.softplus(self.head_weights).view(
             *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
```
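The dominant IPA allocation is the `[*, N_res, N_res, H, P_q, 3]` difference tensor; chunking the query points along the leading dimension computes the same `[N_res, N_res, H, P_q]` squared-distance sum while only ever materializing 1/`_chunks` of it. A layout-simplified sketch of the identity (shapes here are `[N, H, P, 3]` with no batch dims, unlike the model's batched layout; names are mine):

```python
import torch

def chunked_sq_dist(q_pts, k_pts, chunks=10):
    """Sum over xyz of (q_i - k_j)^2 without materializing the full
    [N, N, H, P, 3] difference tensor; peak memory drops ~chunks-fold."""
    ks = k_pts.unsqueeze(0)                           # [1, N, H, P, 3]
    parts = []
    for q in torch.chunk(q_pts.unsqueeze(1), chunks, dim=0):
        d = (q - ks) ** 2                             # [n_i, N, H, P, 3]
        parts.append(d.sum(dim=-1))                   # [n_i, N, H, P]
    return torch.cat(parts, dim=0)                    # [N, N, H, P]

q = torch.randn(20, 4, 8, 3)
k = torch.randn(20, 4, 8, 3)
dense = ((q.unsqueeze(1) - k.unsqueeze(0)) ** 2).sum(-1)
assert torch.allclose(chunked_sq_dist(q, k), dense)
```

The `o_pt` hunk below applies the same chunk-compute-concatenate pattern to the attention-weighted point sum.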
```diff
@@ -430,6 +448,8 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * C_hidden]
         o = flatten_final_dims(o, 2)
+        torch.cuda.empty_cache()

         # As DeepMind explains, this manual matmul ensures that the operation
         # happens in float32.
@@ -447,14 +467,34 @@ class InvariantPointAttention(nn.Module):
             # [*, N_res, H * P_v]
             o_pt_norm = o_pt.norm(self.eps)
         else:
-            # [*, H, 3, N_res, P_v]
-            o_pt = torch.sum(
-                (
-                    a[..., None, :, :, None]
-                    * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
-                ),
-                dim=-2,
-            )
+            # chunk permuted_pts
+            permuted_pts = permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+            size0 = permuted_pts.size()[0]
+            a_size0 = a.size()[0]
+            if size0 == 1 or size0 != a_size0:
+                # # [*, H, 3, N_res, P_v]
+                o_pt = torch.sum(
+                    (a[..., None, :, :, None] * permuted_pts),
+                    dim=-2,
+                )
+            else:
+                a_lists = torch.chunk(a[..., None, :, :, None], size0, dim=0)
+                permuted_pts_lists = torch.chunk(permuted_pts, size0, dim=0)
+                _c = None
+                for i in range(size0):
+                    _d = a_lists[i] * permuted_pts_lists[i]
+                    _d = torch.sum(_d[..., None, :, :], dim=-2)
+                    if i == 0:
+                        _c = _d
+                    else:
+                        _c = torch.cat([_c, _d], dim=0)
+                o_pt = torch.sum(_c, dim=-2)
+            del permuted_pts

         # [*, N_res, H, P_v, 3]
         o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
@@ -471,8 +511,11 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H, C_z]
         o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
+        torch.cuda.empty_cache()

         # [*, N_res, H * C_z]
         o_pair = flatten_final_dims(o_pair, 2)
+        del a

         # [*, N_res, C_s]
         if self.is_multimer:
@@ -485,6 +528,7 @@ class InvariantPointAttention(nn.Module):
                 (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
             ).to(dtype=z.dtype)
         )
+        torch.cuda.empty_cache()

         return s
```
```diff
@@ -720,10 +764,16 @@ class StructureModule(nn.Module):
         # [*, N, N, C_z]
         z = self.layer_norm_z(z)
+        # inplace z
+        # z[0] = z[0].contiguous()
+        # torch.cuda.emtpy_cache()
+        # z[0] = self.layer_norm_z(z[0])

         # [*, N, C_s]
         s_initial = s
         s = self.linear_in(s)
+        torch.cuda.empty_cache()

         # [*, N]
         rigids = Rigid.identity(
@@ -737,9 +787,12 @@ class StructureModule(nn.Module):
         for i in range(self.no_blocks):
             # [*, N, C_s]
             s = s + self.ipa(s, z, rigids, mask)
+            del z
             s = self.ipa_dropout(s)
+            torch.cuda.empty_cache()
             s = self.layer_norm_ipa(s)
             s = self.transition(s)
+            torch.cuda.empty_cache()

             # [*, N]
             rigids = rigids.compose_q_update_vec(self.bb_update(s))
```
inference.py (view file @ e5d72d41)

```diff
@@ -434,21 +434,21 @@ def inference_monomer_model(args):
         with open(unrelaxed_output_path, 'w') as f:
             f.write(protein.to_pdb(unrelaxed_protein))

-        amber_relaxer = relax.AmberRelaxation(
-            use_gpu=True,
-            **config.relax,
-        )
-
-        # Relax the prediction.
-        t = time.perf_counter()
-        relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-        print(f"Relaxation time: {time.perf_counter() - t}")
-
-        # Save the relaxed PDB.
-        relaxed_output_path = os.path.join(args.output_dir,
-                                           f'{tag}_{args.model_name}_relaxed.pdb')
-        with open(relaxed_output_path, 'w') as f:
-            f.write(relaxed_pdb_str)
+        # amber_relaxer = relax.AmberRelaxation(
+        #     use_gpu=True,
+        #     **config.relax,
+        # )
+
+        # # Relax the prediction.
+        # t = time.perf_counter()
+        # relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+        # print(f"Relaxation time: {time.perf_counter() - t}")
+
+        # # Save the relaxed PDB.
+        # relaxed_output_path = os.path.join(args.output_dir,
+        #                                    f'{tag}_{args.model_name}_relaxed.pdb')
+        # with open(relaxed_output_path, 'w') as f:
+        #     f.write(relaxed_pdb_str)

 if __name__ == "__main__":
```
inference.sh (view file @ e5d72d41)

```diff
@@ -14,4 +14,6 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files \
     --jackhmmer_binary_path `which jackhmmer` \
     --hhblits_binary_path `which hhblits` \
     --hhsearch_binary_path `which hhsearch` \
-    --kalign_binary_path `which kalign`
+    --kalign_binary_path `which kalign` \
+    --chunk_size 4 \
+    --inplace
```