OpenDAS / FastFold · Commits

Commit e5d72d41
Authored Dec 07, 2022 by zhuww
Parent: 52919c63

infer longer sequences using memory optimization
Showing 10 changed files with 231 additions and 69 deletions (+231, −69)
fastfold/distributed/comm.py (+61, −4)
fastfold/model/fastnn/embedders.py (+2, −0)
fastfold/model/fastnn/evoformer.py (+9, −4)
fastfold/model/fastnn/msa.py (+13, −5)
fastfold/model/fastnn/ops.py (+5, −0)
fastfold/model/fastnn/template.py (+6, −3)
fastfold/model/hub/alphafold.py (+51, −24)
fastfold/model/nn/structure_module.py (+66, −13)
inference.py (+15, −15)
inference.sh (+3, −1)
fastfold/distributed/comm.py

@@ -27,14 +27,19 @@ def _reduce(tensor: Tensor) -> Tensor:
     return tensor


-def _split(tensor: Tensor, dim: int = -1) -> Tensor:
+def _split(tensor: Tensor, dim: int = -1, drop_unused=False) -> Tensor:
     if gpc.get_world_size(ParallelMode.TENSOR) == 1:
         return tensor

     split_size = divide(tensor.shape[dim], gpc.get_world_size(ParallelMode.TENSOR))
     tensor_list = torch.split(tensor, split_size, dim=dim)

-    output = tensor_list[gpc.get_local_rank(ParallelMode.TENSOR)].contiguous()
+    rank = gpc.get_local_rank(ParallelMode.TENSOR)
+    if not drop_unused:
+        output = tensor_list[rank].contiguous()
+    else:
+        output = tensor_list[rank].contiguous().clone()

     return output

@@ -65,6 +70,55 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
     return output


+def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
+    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
+        return tensor
+
+    if dim == 1 and list(tensor.shape)[0] == 1:
+        output_shape = list(tensor.shape)
+        output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
+        output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
+        world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
+        tensor_list = []
+        for t in world_list:
+            tensor_list.extend(t.chunk(chunks, dim=1))
+        chunk_tensor = tensor.chunk(chunks, dim=1)
+        for i in range(chunks):
+            _chunk_list = [tensor_list[j * 4 + i] for j in range(4)]
+            _chunk_tensor = chunk_tensor[i]
+            dist.all_gather(list(_chunk_list), _chunk_tensor, group=gpc.get_group(ParallelMode.TENSOR), async_op=False)
+            torch.cuda.empty_cache()
+    else:
+        output_shape = list(tensor.shape)
+        output_shape[0] *= gpc.get_world_size(ParallelMode.TENSOR)
+        output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
+        world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=0)
+        tensor_list = []
+        for t in world_list:
+            tensor_list.extend(t.chunk(chunks, dim=0))
+        chunk_tensor = tensor.chunk(chunks, dim=0)
+        for i in range(chunks):
+            _chunk_list = [tensor_list[j * 4 + i] for j in range(4)]
+            _chunk_tensor = chunk_tensor[i]
+            dist.all_gather(list(_chunk_list), _chunk_tensor, group=gpc.get_group(ParallelMode.TENSOR), async_op=False)
+            torch.cuda.empty_cache()
+
+    return output
+
+
 def copy(input: Tensor) -> Tensor:
     if torch.is_grad_enabled() and input.requires_grad:
         input = Copy.apply(input)

@@ -122,11 +176,14 @@ class Reduce(torch.autograd.Function):
     return grad_output


-def gather(input: Tensor, dim: int = -1) -> Tensor:
+def gather(input: Tensor, dim: int = -1, chunks: int = None) -> Tensor:
     if torch.is_grad_enabled() and input.requires_grad:
         input = Gather.apply(input, dim)
     else:
-        input = _gather(input, dim=dim)
+        if chunks is None:
+            input = _gather(input, dim=dim)
+        else:
+            input = _chunk_gather(input, dim=dim, chunks=chunks)
     return input
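Note: the new `_chunk_gather` above trades one large all-gather for `chunks` smaller ones, so only a fraction of the communication buffer is live at any moment (the hard-coded factor 4 matches FastFold's 4-way tensor-parallel group). A minimal standalone sketch of the same idea, assuming only an initialized torch.distributed default group and a shard size divisible by `chunks`; this is an illustration, not FastFold's exact API:

    import torch
    import torch.distributed as dist

    def chunked_all_gather(tensor: torch.Tensor, chunks: int = 4) -> torch.Tensor:
        """All-gather a shard (split along dim 0) in `chunks` pieces.

        Gathering piece by piece keeps each temporary buffer small, which lowers
        peak GPU memory compared with one all_gather of the whole shard.
        """
        world_size = dist.get_world_size()
        if world_size == 1:
            return tensor

        # Pre-allocate the gathered result and view it as world_size * chunks slots.
        out_shape = list(tensor.shape)
        out_shape[0] *= world_size
        output = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
        slots = []
        for per_rank in output.chunk(world_size, dim=0):
            slots.extend(per_rank.chunk(chunks, dim=0))

        # Slot j * chunks + i receives chunk i of rank j's shard.
        local = tensor.chunk(chunks, dim=0)
        for i in range(chunks):
            dist.all_gather([slots[j * chunks + i] for j in range(world_size)],
                            local[i].contiguous())
        return output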
fastfold/model/fastnn/embedders.py

@@ -224,6 +224,7 @@ class TemplateEmbedder(nn.Module):
                 ).to(t.device)
             del tt, single_template_feats
+            torch.cuda.empty_cache()
         template_embeds = dict_multimap(
             partial(torch.cat, dim=templ_dim),

@@ -245,6 +246,7 @@ class TemplateEmbedder(nn.Module):
             template_mask=batch["template_mask"].to(dtype=z.dtype),
             chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
         )
+        torch.cuda.empty_cache()
         ret = {}
         ret["template_pair_embedding"] = z
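Note: this file only gains torch.cuda.empty_cache() calls after large intermediates are dropped; empty_cache returns cached allocator blocks to the CUDA driver, which makes the next very large allocation less likely to fail from fragmentation, at the cost of slower subsequent allocations. A small illustration of the effect on allocator statistics, assuming a CUDA device is available (the exact numbers depend on the device and allocator):

    import torch

    if torch.cuda.is_available():
        device = torch.device("cuda")
        x = torch.randn(1024, 1024, 64, device=device)   # ~256 MiB in fp32
        print("allocated:", torch.cuda.memory_allocated() // 2**20, "MiB")
        print("reserved: ", torch.cuda.memory_reserved() // 2**20, "MiB")

        del x  # frees the allocation, but the blocks typically stay cached
        print("after del, reserved:", torch.cuda.memory_reserved() // 2**20, "MiB")

        torch.cuda.empty_cache()  # returns the cached blocks to the driver
        print("after empty_cache, reserved:", torch.cuda.memory_reserved() // 2**20, "MiB")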
fastfold/model/fastnn/evoformer.py

@@ -105,9 +105,13 @@ class Evoformer(nn.Module):
             m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size))
             z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
+            torch.cuda.empty_cache()
-            m[0] = scatter(m[0], dim=1)
+            m[0] = scatter(m[0], dim=1, drop_unused=True)
+            torch.cuda.empty_cache()
-            z[0] = scatter(z[0], dim=1)
+            z[0] = scatter(z[0], dim=1, drop_unused=True)
+            torch.cuda.empty_cache()

             msa_mask = msa_mask.unsqueeze(0)
             pair_mask = pair_mask.unsqueeze(0)

@@ -137,9 +141,10 @@ class Evoformer(nn.Module):
         if self.last_block:
             m[0] = m[0].squeeze(0)
             z[0] = z[0].squeeze(0)
+            torch.cuda.empty_cache()
-            m[0] = gather(m[0], dim=0)
+            m[0] = gather(m[0], dim=0, chunks=4)
-            z[0] = gather(z[0], dim=0)
+            z[0] = gather(z[0], dim=0, chunks=4)
             m[0] = m[0][:, :-padding_size, :]
             z[0] = z[0][:-padding_size, :-padding_size, :]
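Note: `drop_unused=True` works together with the `_split` change in comm.py above: the local shard is cloned, so it stops sharing storage with the full padded tensor and that large tensor can actually be released. A minimal sketch of the difference between keeping a view and keeping a clone, in plain PyTorch:

    import torch

    full = torch.randn(8, 1024, 1024)  # stand-in for the full padded tensor
    shard_view = torch.split(full, 2, dim=0)[0]                       # view: shares storage with `full`
    shard_copy = torch.split(full, 2, dim=0)[0].contiguous().clone()  # fresh storage

    print(shard_view.data_ptr() == full.data_ptr())  # True: same underlying buffer
    print(shard_copy.data_ptr() == full.data_ptr())  # False: independent buffer

    # While `shard_view` is alive, deleting `full` cannot release its storage;
    # with the clone (the drop_unused=True path), `del full` really frees the big buffer.
    del full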
fastfold/model/fastnn/msa.py

@@ -217,7 +217,7 @@ class ExtraMSABlock(nn.Module):
         seq_len = pair_mask.size(-1)
         seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
         seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
         if self.first_block:
             m = m.unsqueeze(0)
             z = z.unsqueeze(0)

@@ -225,6 +225,7 @@ class ExtraMSABlock(nn.Module):
             m = torch.nn.functional.pad(
                 m, (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
             )
             z = torch.nn.functional.pad(
                 z, (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
             )

@@ -286,6 +287,8 @@ class ExtraMSABlock(nn.Module):
         seq_len = pair_mask.size(-1)
         seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
         seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
+        torch.cuda.empty_cache()
         if self.first_block:
             m[0] = m[0].unsqueeze(0)

@@ -294,12 +297,16 @@ class ExtraMSABlock(nn.Module):
             m[0] = torch.nn.functional.pad(
                 m[0], (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
             )
+            torch.cuda.empty_cache()
             z[0] = torch.nn.functional.pad(
                 z[0], (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
             )
+            torch.cuda.empty_cache()
-            m[0] = scatter(m[0], dim=1) if not self.is_multimer else scatter(m[0], dim=2)
+            m[0] = scatter(m[0], dim=1, drop_unused=True) if not self.is_multimer else scatter(m[0], dim=2)
-            z[0] = scatter(z[0], dim=1)
+            torch.cuda.empty_cache()
+            z[0] = scatter(z[0], dim=1, drop_unused=True)
+            torch.cuda.empty_cache()

             msa_mask = msa_mask.unsqueeze(0)
             pair_mask = pair_mask.unsqueeze(0)

@@ -332,8 +339,9 @@ class ExtraMSABlock(nn.Module):
         if self.last_block:
-            m[0] = gather(m[0], dim=1) if not self.is_multimer else gather(m[0], dim=2)
+            m[0] = gather(m[0], dim=1, chunks=4) if not self.is_multimer else gather(m[0], dim=2)
-            z[0] = gather(z[0], dim=1)
+            torch.cuda.empty_cache()
+            z[0] = gather(z[0], dim=1, chunks=4)
             m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
             z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]
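Note: the padding arithmetic in these hunks always rounds up to the next multiple of dap_size, even when the size already divides evenly, so the trailing `[:-padding_size]` slices never degenerate into `[:-0]` (which would select nothing). A small illustration of the formula used above:

    # Padding size as written in the diff: always at least 1, at most dap_size.
    def padding_size(n: int, dap_size: int) -> int:
        return (int(n / dap_size) + 1) * dap_size - n

    print(padding_size(254, 4))  # 2 -> padded length 256
    print(padding_size(256, 4))  # 4 -> padded length 260, not 256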
fastfold/model/fastnn/ops.py

@@ -207,7 +207,9 @@ class OutProductMean(nn.Module):
         norm = torch.einsum('bsid,bsjd->bijd', M_mask_col, M_mask) + 1e-3
         right_act_all = gather_async_opp(right_act_all, work, dim=2)
+        torch.cuda.empty_cache()
         right_act_all = M_mask * right_act_all
+        torch.cuda.empty_cache()

         para_dim = left_act.shape[2]
         chunk_size = CHUNK_SIZE

@@ -1272,11 +1274,14 @@ class InputEmbedder(nn.Module):
         reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
         pair_emb = d[..., None] - reshaped_bins
+        del d, reshaped_bins
+        torch.cuda.empty_cache()
         pair_emb = torch.argmin(torch.abs(pair_emb), dim=-1)
         pair_emb = nn.functional.one_hot(pair_emb, num_classes=len(boundaries)).float().type(ri.dtype)
         pair_emb = self.linear_relpos(pair_emb)
         pair_emb += tf_emb_i[..., None, :]
         pair_emb += tf_emb_j[..., None, :, :]
+        torch.cuda.empty_cache()

         # [*, N_clust, N_res, c_m]
         n_clust = msa.shape[-3]
fastfold/model/fastnn/template.py

@@ -282,9 +282,11 @@ class TemplatePairBlock(nn.Module):
         seq_length = mask.size(-1)
         padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
+        torch.cuda.empty_cache()
         if self.first_block:
             z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
-            z[0] = scatter(z[0], dim=1)
+            z[0] = scatter(z[0], dim=1, drop_unused=True)
+            torch.cuda.empty_cache()
         mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))

@@ -294,8 +296,8 @@ class TemplatePairBlock(nn.Module):
             single = z[0][i].unsqueeze(-4).to(mask.device)
             single_mask = mask[i].unsqueeze(-3)
-            single_mask_row = scatter(single_mask, dim=1)
-            single_mask_col = scatter(single_mask, dim=2)
+            single_mask_row = scatter(single_mask, dim=1, drop_unused=True)
+            single_mask_col = scatter(single_mask, dim=2, drop_unused=True)
             single = self.TriangleAttentionStartingNode(single, single_mask_row)
             single = row_to_col(single)

@@ -307,6 +309,7 @@ class TemplatePairBlock(nn.Module):
             single = self.PairTransition(single)
             single = col_to_row(single)
             z[0][i] = single.to(z[0].device)
         # z = torch.cat(single_templates, dim=-4)
         if self.last_block:
fastfold/model/hub/alphafold.py

@@ -169,7 +169,8 @@ class AlphaFold(nn.Module):
         return ret

-    def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
+    def iteration(self, feats, prevs, _recycle=True, no_iter=0):
+        torch.cuda.empty_cache()
         # Primary output dictionary
         outputs = {}

@@ -203,7 +204,10 @@ class AlphaFold(nn.Module):
             if not self.globals.is_multimer
             else self.input_embedder(feats)
         )
+        torch.cuda.empty_cache()
+
+        m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])

         # Initialize the recycling embeddings, if needs be
         if None in [m_1_prev, z_prev, x_prev]:
             # [*, N, C_m]

@@ -251,6 +255,8 @@ class AlphaFold(nn.Module):
         # Possibly prevents memory fragmentation
         del m_1_prev, z_prev, x_prev
+        torch.cuda.empty_cache()

         # Embed the templates + merge with MSA/pair embeddings
         if self.config.template.enabled:

@@ -330,6 +336,8 @@ class AlphaFold(nn.Module):
             # [*, S_e, N, C_e]
             extra_msa_feat = extra_msa_fn(feats)
             extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
+            torch.cuda.empty_cache()

             # [*, N, N, C_z]
             if not self.globals.inplace:

@@ -353,6 +361,8 @@ class AlphaFold(nn.Module):
                     _mask_trans=self.config._mask_trans,
                 )[0]
             del extra_msa_feat, extra_msa_fn
+            torch.cuda.empty_cache()

         # Run MSA + pair embeddings through the trunk of the network
         # m: [*, S, N, C_m]

@@ -380,36 +390,49 @@ class AlphaFold(nn.Module):
         )
         m = m[0]
         z = z[0]
+        torch.cuda.empty_cache()

-        outputs["msa"] = m[..., :n_seq, :, :]
-        outputs["pair"] = z
-        outputs["single"] = s
+        if no_iter == 3:
+            outputs["msa"] = m[..., :n_seq, :, :]
+            outputs["pair"] = z
+            outputs["single"] = s

         # Predict 3D structure
-        outputs["sm"] = self.structure_module(
+        z = [z]
+        outputs_sm = self.structure_module(
             s,
             z,
             feats["aatype"],
             mask=feats["seq_mask"].to(dtype=s.dtype),
         )
-        outputs["final_atom_positions"] = atom14_to_atom37(
-            outputs["sm"]["positions"][-1], feats
-        )
-        outputs["final_atom_mask"] = feats["atom37_atom_exists"]
-        outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
-
-        # Save embeddings for use during the next recycling iteration
-        # [*, N, C_m]
-        m_1_prev = m[..., 0, :, :]
-
-        # [*, N, N, C_z]
-        z_prev = z
-
-        # [*, N, 3]
-        x_prev = outputs["final_atom_positions"]
-
-        return outputs, m_1_prev, z_prev, x_prev
+        torch.cuda.empty_cache()
+        if no_iter == 3:
+            m_1_prev, z_prev, x_prev = None, None, None
+            outputs["sm"] = outputs_sm
+            outputs["final_atom_positions"] = atom14_to_atom37(
+                outputs["sm"]["positions"][-1], feats
+            )
+            outputs["final_atom_mask"] = feats["atom37_atom_exists"]
+            outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
+        else:
+            # Save embeddings for use during the next recycling iteration
+            # [*, N, C_m]
+            m_1_prev = m[..., 0, :, :]
+            # [*, N, N, C_z]
+            z_prev = z
+            # [*, N, 3]
+            x_prev = outputs["final_atom_positions"]
+
+        if no_iter != 3:
+            return None, m_1_prev, z_prev, x_prev
+        else:
+            return outputs, m_1_prev, z_prev, x_prev

     def _disable_activation_checkpointing(self):
         self.template_embedder.template_pair_stack.blocks_per_ckpt = None

@@ -482,6 +505,7 @@ class AlphaFold(nn.Module):
         """
         # Initialize recycling embeddings
         m_1_prev, z_prev, x_prev = None, None, None
+        prevs = [m_1_prev, z_prev, x_prev]

         # Disable activation checkpointing for the first few recycling iters
         is_grad_enabled = torch.is_grad_enabled()

@@ -506,11 +530,14 @@ class AlphaFold(nn.Module):
                 # Run the next iteration of the model
                 outputs, m_1_prev, z_prev, x_prev = self.iteration(
                     feats,
-                    m_1_prev,
-                    z_prev,
-                    x_prev,
-                    _recycle=(num_iters > 1)
+                    prevs,
+                    _recycle=(num_iters > 1),
+                    no_iter=cycle_no
                 )
+                if cycle_no != 3:
+                    prevs = [m_1_prev, z_prev, x_prev]
+                    del m_1_prev, z_prev, x_prev

         # Run auxiliary heads
         outputs.update(self.aux_heads(outputs))
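Note: `iteration` now receives the recycling embeddings inside a single `prevs` list and pops them out of it, so the caller holds no extra references between recycling cycles and each previous embedding can be freed as soon as the new cycle has consumed it. A minimal sketch of that hand-off pattern (not FastFold's actual module code; `feats` here is just a placeholder tensor):

    import torch

    def iteration(feats, prevs, no_iter=0):
        # Same unpacking trick as the commit: this empties `prevs` in place,
        # so only the local names below keep the previous tensors alive.
        m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
        # ... consume the previous embeddings, then produce the next ones ...
        m_1_prev = torch.zeros_like(feats) if m_1_prev is None else m_1_prev + 1
        return None, m_1_prev, z_prev, x_prev

    feats = torch.zeros(3)
    prevs = [None, None, None]
    for cycle_no in range(4):
        outputs, m_1_prev, z_prev, x_prev = iteration(feats, prevs, no_iter=cycle_no)
        if cycle_no != 3:
            prevs = [m_1_prev, z_prev, x_prev]  # only the list survives between cycles
            del m_1_prev, z_prev, x_prev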
fastfold/model/nn/structure_module.py

@@ -389,6 +389,8 @@ class InvariantPointAttention(nn.Module):
         )
         a *= math.sqrt(1.0 / (3 * self.c_hidden))
         a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))
+        torch.cuda.empty_cache()

         if self.is_multimer:
             # [*, N_res, N_res, H, P_q, 3]

@@ -396,11 +398,27 @@ class InvariantPointAttention(nn.Module):
             # [*, N_res, N_res, H, P_q]
             pt_att = sum([c ** 2 for c in pt_att])
         else:
-            # [*, N_res, N_res, H, P_q, 3]
-            pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
-            pt_att = pt_att ** 2
-            # [*, N_res, N_res, H, P_q]
-            pt_att = sum(torch.unbind(pt_att, dim=-1))
+            # # [*, N_res, N_res, H, P_q, 3]
+            # pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
+            # pt_att = pt_att**2
+            # # [*, N_res, N_res, H, P_q]
+            # pt_att = sum(torch.unbind(pt_att, dim=-1))
+            _chunks = 10
+            _ks = k_pts.unsqueeze(-5)
+            _qlist = torch.chunk(q_pts.unsqueeze(-4), chunks=_chunks, dim=0)
+            pt_att = None
+            for _i in range(_chunks):
+                _pt = _qlist[_i] - _ks
+                _pt = _pt ** 2
+                _pt = sum(torch.unbind(_pt, dim=-1))
+                if _i == 0:
+                    pt_att = _pt
+                else:
+                    pt_att = torch.cat([pt_att, _pt], dim=0)
+            torch.cuda.empty_cache()

         head_weights = self.softplus(self.head_weights).view(
             *((1,) * len(pt_att.shape[:-2]) + (-1, 1))

@@ -430,6 +448,8 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * C_hidden]
         o = flatten_final_dims(o, 2)
+        torch.cuda.empty_cache()

         # As DeepMind explains, this manual matmul ensures that the operation
         # happens in float32.

@@ -447,14 +467,34 @@ class InvariantPointAttention(nn.Module):
             # [*, N_res, H * P_v]
             o_pt_norm = o_pt.norm(self.eps)
         else:
-            # [*, H, 3, N_res, P_v]
-            o_pt = torch.sum(
-                (a[..., None, :, :, None]
-                * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
-                ),
-                dim=-2,
-            )
+            # chunk permuted_pts
+            permuted_pts = permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+            size0 = permuted_pts.size()[0]
+            a_size0 = a.size()[0]
+            if size0 == 1 or size0 != a_size0:
+                # # [*, H, 3, N_res, P_v]
+                o_pt = torch.sum(
+                    (a[..., None, :, :, None]
+                    * permuted_pts),
+                    dim=-2,
+                )
+            else:
+                a_lists = torch.chunk(a[..., None, :, :, None], size0, dim=0)
+                permuted_pts_lists = torch.chunk(permuted_pts, size0, dim=0)
+                _c = None
+                for i in range(size0):
+                    _d = a_lists[i] * permuted_pts_lists[i]
+                    _d = torch.sum(_d[..., None, :, :], dim=-2)
+                    if i == 0:
+                        _c = _d
+                    else:
+                        _c = torch.cat([_c, _d], dim=0)
+                o_pt = torch.sum(_c, dim=-2)
+            del permuted_pts

             # [*, N_res, H, P_v, 3]
             o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))

@@ -471,8 +511,11 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H, C_z]
         o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
+        torch.cuda.empty_cache()

         # [*, N_res, H * C_z]
         o_pair = flatten_final_dims(o_pair, 2)
+        del a

         # [*, N_res, C_s]
         if self.is_multimer:

@@ -485,6 +528,7 @@ class InvariantPointAttention(nn.Module):
                 (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
             ).to(dtype=z.dtype)
         )
+        torch.cuda.empty_cache()

         return s

@@ -720,10 +764,16 @@ class StructureModule(nn.Module):
         # [*, N, N, C_z]
         z = self.layer_norm_z(z)
+        # inplace z
+        # z[0] = z[0].contiguous()
+        # torch.cuda.emtpy_cache()
+        # z[0] = self.layer_norm_z(z[0])

         # [*, N, C_s]
         s_initial = s
         s = self.linear_in(s)
+        torch.cuda.empty_cache()

         # [*, N]
         rigids = Rigid.identity(

@@ -737,9 +787,12 @@ class StructureModule(nn.Module):
         for i in range(self.no_blocks):
             # [*, N, C_s]
             s = s + self.ipa(s, z, rigids, mask)
+            del z
             s = self.ipa_dropout(s)
+            torch.cuda.empty_cache()
             s = self.layer_norm_ipa(s)
             s = self.transition(s)
+            torch.cuda.empty_cache()

             # [*, N]
             rigids = rigids.compose_q_update_vec(self.bb_update(s))
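Note: the InvariantPointAttention change above replaces the full pairwise point-difference tensor with a loop over chunks of the query points, so only one chunk of the large intermediate is alive at a time. A standalone sketch of that chunking idea with simplified shapes (plain PyTorch, not the module's exact tensors):

    import torch

    def chunked_sq_dist(q: torch.Tensor, k: torch.Tensor, chunks: int = 10) -> torch.Tensor:
        """Sum of squared differences between all pairs of points, computed in
        chunks along the first dimension so the full [N, M, 3] intermediate
        never exists all at once."""
        k_exp = k.unsqueeze(0)                                        # [1, M, 3]
        parts = []
        for q_chunk in torch.chunk(q.unsqueeze(1), chunks, dim=0):    # [n_i, 1, 3]
            diff = q_chunk - k_exp                                    # [n_i, M, 3]
            parts.append((diff ** 2).sum(dim=-1))                     # [n_i, M]
        return torch.cat(parts, dim=0)                                # [N, M]

    q = torch.randn(100, 3)
    k = torch.randn(80, 3)
    reference = ((q.unsqueeze(1) - k.unsqueeze(0)) ** 2).sum(-1)
    assert torch.allclose(chunked_sq_dist(q, k), reference, atol=1e-5)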
inference.py

@@ -434,21 +434,21 @@ def inference_monomer_model(args):
     with open(unrelaxed_output_path, 'w') as f:
         f.write(protein.to_pdb(unrelaxed_protein))

-    amber_relaxer = relax.AmberRelaxation(
-        use_gpu=True,
-        **config.relax,
-    )
-    # Relax the prediction.
-    t = time.perf_counter()
-    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-    print(f"Relaxation time: {time.perf_counter() - t}")
-    # Save the relaxed PDB.
-    relaxed_output_path = os.path.join(args.output_dir,
-                                       f'{tag}_{args.model_name}_relaxed.pdb')
-    with open(relaxed_output_path, 'w') as f:
-        f.write(relaxed_pdb_str)
+    # amber_relaxer = relax.AmberRelaxation(
+    #     use_gpu=True,
+    #     **config.relax,
+    # )
+    # # Relax the prediction.
+    # t = time.perf_counter()
+    # relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+    # print(f"Relaxation time: {time.perf_counter() - t}")
+    # # Save the relaxed PDB.
+    # relaxed_output_path = os.path.join(args.output_dir,
+    #                                    f'{tag}_{args.model_name}_relaxed.pdb')
+    # with open(relaxed_output_path, 'w') as f:
+    #     f.write(relaxed_pdb_str)

 if __name__ == "__main__":
inference.sh

@@ -14,4 +14,6 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files \
     --jackhmmer_binary_path `which jackhmmer` \
     --hhblits_binary_path `which hhblits` \
     --hhsearch_binary_path `which hhsearch` \
-    --kalign_binary_path `which kalign`
+    --kalign_binary_path `which kalign` \
+    --chunk_size 4 \
+    --inplace