OpenDAS / FastFold, commit e5d72d41

Commit e5d72d41, authored Dec 07, 2022 by zhuww
Parent commit: 52919c63

    infer longer sequences using memory optimization

Showing 10 changed files with 231 additions and 69 deletions (+231 -69)
fastfold/distributed/comm.py           +61  -4
fastfold/model/fastnn/embedders.py      +2  -0
fastfold/model/fastnn/evoformer.py      +9  -4
fastfold/model/fastnn/msa.py           +13  -5
fastfold/model/fastnn/ops.py            +5  -0
fastfold/model/fastnn/template.py       +6  -3
fastfold/model/hub/alphafold.py        +51 -24
fastfold/model/nn/structure_module.py  +66 -13
inference.py                           +15 -15
inference.sh                            +3  -1
fastfold/distributed/comm.py

```diff
@@ -27,14 +27,19 @@ def _reduce(tensor: Tensor) -> Tensor:
     return tensor
 
 
-def _split(tensor: Tensor, dim: int = -1) -> Tensor:
+def _split(tensor: Tensor, dim: int = -1, drop_unused=False) -> Tensor:
     if gpc.get_world_size(ParallelMode.TENSOR) == 1:
         return tensor
 
     split_size = divide(tensor.shape[dim], gpc.get_world_size(ParallelMode.TENSOR))
     tensor_list = torch.split(tensor, split_size, dim=dim)
 
-    output = tensor_list[gpc.get_local_rank(ParallelMode.TENSOR)].contiguous()
+    rank = gpc.get_local_rank(ParallelMode.TENSOR)
+    if not drop_unused:
+        output = tensor_list[rank].contiguous()
+    else:
+        output = tensor_list[rank].contiguous().clone()
 
     return output
@@ -65,6 +70,55 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
     return output
 
 
+def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
+    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
+        return tensor
+
+    if dim == 1 and list(tensor.shape)[0] == 1:
+        output_shape = list(tensor.shape)
+        output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
+        output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
+        world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
+        tensor_list = []
+        for t in world_list:
+            tensor_list.extend(t.chunk(chunks, dim=1))
+        chunk_tensor = tensor.chunk(chunks, dim=1)
+        for i in range(chunks):
+            _chunk_list = [tensor_list[j * 4 + i] for j in range(4)]
+            _chunk_tensor = chunk_tensor[i]
+            dist.all_gather(list(_chunk_list), _chunk_tensor, group=gpc.get_group(ParallelMode.TENSOR), async_op=False)
+            torch.cuda.empty_cache()
+    else:
+        output_shape = list(tensor.shape)
+        output_shape[0] *= gpc.get_world_size(ParallelMode.TENSOR)
+        output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
+        world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=0)
+        tensor_list = []
+        for t in world_list:
+            tensor_list.extend(t.chunk(chunks, dim=0))
+        chunk_tensor = tensor.chunk(chunks, dim=0)
+        for i in range(chunks):
+            _chunk_list = [tensor_list[j * 4 + i] for j in range(4)]
+            _chunk_tensor = chunk_tensor[i]
+            dist.all_gather(list(_chunk_list), _chunk_tensor, group=gpc.get_group(ParallelMode.TENSOR), async_op=False)
+            torch.cuda.empty_cache()
+
+    return output
+
+
 def copy(input: Tensor) -> Tensor:
     if torch.is_grad_enabled() and input.requires_grad:
         input = Copy.apply(input)
@@ -122,11 +176,14 @@ class Reduce(torch.autograd.Function):
     return grad_output
 
 
-def gather(input: Tensor, dim: int = -1) -> Tensor:
+def gather(input: Tensor, dim: int = -1, chunks: int = None) -> Tensor:
     if torch.is_grad_enabled() and input.requires_grad:
         input = Gather.apply(input, dim)
     else:
-        input = _gather(input, dim=dim)
+        if chunks is None:
+            input = _gather(input, dim=dim)
+        else:
+            input = _chunk_gather(input, dim=dim, chunks=chunks)
     return input
```
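The core of the memory optimization is `_chunk_gather`: instead of one `all_gather` over the whole MSA or pair tensor, the full result is preallocated once and filled by several smaller collectives, so only a chunk-sized communication buffer is live at any time. Below is a minimal single-process sketch of that indexing scheme; `chunked_gather`, `fake_all_gather`, `WORLD_SIZE`, and `CHUNKS` are illustrative stand-ins, not FastFold API (the commit hard-codes 4 ranks via `range(4)` and issues real `dist.all_gather` calls on the tensor-parallel group).

```python
import torch

WORLD_SIZE = 4   # _chunk_gather above hard-codes 4 ranks via range(4)
CHUNKS = 4       # matches gather(..., chunks=4) used in the Evoformer / ExtraMSA blocks


def fake_all_gather(out_views, shard_chunk):
    # Stand-in for dist.all_gather: every "rank" contributes the same chunk here,
    # just to show how the preallocated views are filled piece by piece.
    for view in out_views:
        view.copy_(shard_chunk)


def chunked_gather(shard: torch.Tensor, dim: int = 0, chunks: int = CHUNKS) -> torch.Tensor:
    # Preallocate the gathered tensor once ...
    out_shape = list(shard.shape)
    out_shape[dim] *= WORLD_SIZE
    output = torch.empty(out_shape, dtype=shard.dtype, device=shard.device)

    # ... then carve it into per-rank, per-chunk views and fill them with
    # `chunks` small collectives instead of one large one.
    views = []
    for per_rank in output.chunk(WORLD_SIZE, dim=dim):
        views.extend(per_rank.chunk(chunks, dim=dim))
    shard_chunks = shard.chunk(chunks, dim=dim)

    for i in range(chunks):
        # mirrors tensor_list[j * 4 + i] in _chunk_gather (rank-major layout)
        targets = [views[r * chunks + i] for r in range(WORLD_SIZE)]
        fake_all_gather(targets, shard_chunks[i])
    return output


if __name__ == "__main__":
    shard = torch.arange(32, dtype=torch.float32).reshape(8, 4)  # this rank's slice
    full = chunked_gather(shard, dim=0)
    print(full.shape)  # torch.Size([32, 4]): WORLD_SIZE x larger along dim 0
```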
fastfold/model/fastnn/embedders.py

```diff
@@ -224,6 +224,7 @@ class TemplateEmbedder(nn.Module):
             ).to(t.device)
 
             del tt, single_template_feats
+            torch.cuda.empty_cache()
 
         template_embeds = dict_multimap(
             partial(torch.cat, dim=templ_dim),
@@ -245,6 +246,7 @@ class TemplateEmbedder(nn.Module):
             template_mask=batch["template_mask"].to(dtype=z.dtype),
             chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
         )
+        torch.cuda.empty_cache()
 
         ret = {}
 
         ret["template_pair_embedding"] = z
```
fastfold/model/fastnn/evoformer.py

```diff
@@ -105,9 +105,13 @@ class Evoformer(nn.Module):
             m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size))
             z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
-            m[0] = scatter(m[0], dim=1)
-            z[0] = scatter(z[0], dim=1)
+            torch.cuda.empty_cache()
+            m[0] = scatter(m[0], dim=1, drop_unused=True)
+            torch.cuda.empty_cache()
+            z[0] = scatter(z[0], dim=1, drop_unused=True)
+            torch.cuda.empty_cache()
 
             msa_mask = msa_mask.unsqueeze(0)
             pair_mask = pair_mask.unsqueeze(0)
@@ -137,9 +141,10 @@ class Evoformer(nn.Module):
         if self.last_block:
             m[0] = m[0].squeeze(0)
             z[0] = z[0].squeeze(0)
-            m[0] = gather(m[0], dim=0)
-            z[0] = gather(z[0], dim=0)
+            torch.cuda.empty_cache()
+            m[0] = gather(m[0], dim=0, chunks=4)
+            z[0] = gather(z[0], dim=0, chunks=4)
 
             m[0] = m[0][:, :-padding_size, :]
             z[0] = z[0][:-padding_size, :-padding_size, :]
```
fastfold/model/fastnn/msa.py

```diff
@@ -217,7 +217,7 @@ class ExtraMSABlock(nn.Module):
         seq_len = pair_mask.size(-1)
         seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
         seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
 
         if self.first_block:
             m = m.unsqueeze(0)
             z = z.unsqueeze(0)
@@ -225,6 +225,7 @@ class ExtraMSABlock(nn.Module):
             m = torch.nn.functional.pad(
                 m, (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
             )
             z = torch.nn.functional.pad(
                 z, (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
             )
@@ -286,6 +287,8 @@ class ExtraMSABlock(nn.Module):
         seq_len = pair_mask.size(-1)
         seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
         seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
+        torch.cuda.empty_cache()
 
         if self.first_block:
             m[0] = m[0].unsqueeze(0)
@@ -294,12 +297,16 @@ class ExtraMSABlock(nn.Module):
             m[0] = torch.nn.functional.pad(
                 m[0], (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
             )
+            torch.cuda.empty_cache()
             z[0] = torch.nn.functional.pad(
                 z[0], (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
             )
+            torch.cuda.empty_cache()
-            m[0] = scatter(m[0], dim=1) if not self.is_multimer else scatter(m[0], dim=2)
-            z[0] = scatter(z[0], dim=1)
+            m[0] = scatter(m[0], dim=1, drop_unused=True) if not self.is_multimer else scatter(m[0], dim=2)
+            torch.cuda.empty_cache()
+            z[0] = scatter(z[0], dim=1, drop_unused=True)
+            torch.cuda.empty_cache()
 
             msa_mask = msa_mask.unsqueeze(0)
             pair_mask = pair_mask.unsqueeze(0)
@@ -332,8 +339,9 @@ class ExtraMSABlock(nn.Module):
         if self.last_block:
-            m[0] = gather(m[0], dim=1) if not self.is_multimer else gather(m[0], dim=2)
-            z[0] = gather(z[0], dim=1)
+            m[0] = gather(m[0], dim=1, chunks=4) if not self.is_multimer else gather(m[0], dim=2)
+            torch.cuda.empty_cache()
+            z[0] = gather(z[0], dim=1, chunks=4)
 
             m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
             z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]
```
fastfold/model/fastnn/ops.py

```diff
@@ -207,7 +207,9 @@ class OutProductMean(nn.Module):
         norm = torch.einsum('bsid,bsjd->bijd', M_mask_col, M_mask) + 1e-3
 
         right_act_all = gather_async_opp(right_act_all, work, dim=2)
+        torch.cuda.empty_cache()
         right_act_all = M_mask * right_act_all
+        torch.cuda.empty_cache()
 
         para_dim = left_act.shape[2]
         chunk_size = CHUNK_SIZE
@@ -1272,11 +1274,14 @@ class InputEmbedder(nn.Module):
         reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
         pair_emb = d[..., None] - reshaped_bins
+        del d, reshaped_bins
+        torch.cuda.empty_cache()
         pair_emb = torch.argmin(torch.abs(pair_emb), dim=-1)
         pair_emb = nn.functional.one_hot(pair_emb, num_classes=len(boundaries)).float().type(ri.dtype)
         pair_emb = self.linear_relpos(pair_emb)
         pair_emb += tf_emb_i[..., None, :]
         pair_emb += tf_emb_j[..., None, :, :]
+        torch.cuda.empty_cache()
 
         # [*, N_clust, N_res, c_m]
         n_clust = msa.shape[-3]
```
fastfold/model/fastnn/template.py

```diff
@@ -282,9 +282,11 @@ class TemplatePairBlock(nn.Module):
         seq_length = mask.size(-1)
         padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
+        torch.cuda.empty_cache()
 
         if self.first_block:
             z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
-            z[0] = scatter(z[0], dim=1)
+            z[0] = scatter(z[0], dim=1, drop_unused=True)
+            torch.cuda.empty_cache()
 
         mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
@@ -294,8 +296,8 @@ class TemplatePairBlock(nn.Module):
             single = z[0][i].unsqueeze(-4).to(mask.device)
             single_mask = mask[i].unsqueeze(-3)
-            single_mask_row = scatter(single_mask, dim=1)
-            single_mask_col = scatter(single_mask, dim=2)
+            single_mask_row = scatter(single_mask, dim=1, drop_unused=True)
+            single_mask_col = scatter(single_mask, dim=2, drop_unused=True)
 
             single = self.TriangleAttentionStartingNode(single, single_mask_row)
             single = row_to_col(single)
@@ -307,6 +309,7 @@ class TemplatePairBlock(nn.Module):
             single = self.PairTransition(single)
             single = col_to_row(single)
             z[0][i] = single.to(z[0].device)
 
         # z = torch.cat(single_templates, dim=-4)
         if self.last_block:
```
fastfold/model/hub/alphafold.py

```diff
@@ -169,7 +169,8 @@ class AlphaFold(nn.Module):
         return ret
 
-    def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
+    def iteration(self, feats, prevs, _recycle=True, no_iter=0):
+        torch.cuda.empty_cache()
         # Primary output dictionary
         outputs = {}
@@ -203,7 +204,10 @@ class AlphaFold(nn.Module):
             if not self.globals.is_multimer
             else self.input_embedder(feats)
         )
+        torch.cuda.empty_cache()
+
+        m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
 
         # Initialize the recycling embeddings, if needs be
         if None in [m_1_prev, z_prev, x_prev]:
             # [*, N, C_m]
@@ -251,6 +255,8 @@ class AlphaFold(nn.Module):
         # Possibly prevents memory fragmentation
         del m_1_prev, z_prev, x_prev
+        torch.cuda.empty_cache()
+
         # Embed the templates + merge with MSA/pair embeddings
         if self.config.template.enabled:
@@ -330,6 +336,8 @@ class AlphaFold(nn.Module):
             # [*, S_e, N, C_e]
             extra_msa_feat = extra_msa_fn(feats)
             extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
+            torch.cuda.empty_cache()
+
             # [*, N, N, C_z]
             if not self.globals.inplace:
@@ -353,6 +361,8 @@ class AlphaFold(nn.Module):
                 _mask_trans=self.config._mask_trans,
             )[0]
             del extra_msa_feat, extra_msa_fn
+            torch.cuda.empty_cache()
+
         # Run MSA + pair embeddings through the trunk of the network
         # m: [*, S, N, C_m]
@@ -380,36 +390,49 @@ class AlphaFold(nn.Module):
         )
 
         m = m[0]
         z = z[0]
+        torch.cuda.empty_cache()
 
-        outputs["msa"] = m[..., :n_seq, :, :]
-        outputs["pair"] = z
-        outputs["single"] = s
+        if no_iter == 3:
+            outputs["msa"] = m[..., :n_seq, :, :]
+            outputs["pair"] = z
+            outputs["single"] = s
 
         # Predict 3D structure
-        outputs["sm"] = self.structure_module(
+        outputs_sm = self.structure_module(
             s,
             z,
             feats["aatype"],
             mask=feats["seq_mask"].to(dtype=s.dtype),
         )
-        outputs["final_atom_positions"] = atom14_to_atom37(
-            outputs["sm"]["positions"][-1], feats
-        )
-        outputs["final_atom_mask"] = feats["atom37_atom_exists"]
-        outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
-
-        # Save embeddings for use during the next recycling iteration
-        # [*, N, C_m]
-        m_1_prev = m[..., 0, :, :]
-
-        # [*, N, N, C_z]
-        z_prev = z
-
-        # [*, N, 3]
-        x_prev = outputs["final_atom_positions"]
-
-        return outputs, m_1_prev, z_prev, x_prev
+        torch.cuda.empty_cache()
+
+        if no_iter == 3:
+            m_1_prev, z_prev, x_prev = None, None, None
+            outputs["sm"] = outputs_sm
+            outputs["final_atom_positions"] = atom14_to_atom37(
+                outputs["sm"]["positions"][-1], feats
+            )
+            outputs["final_atom_mask"] = feats["atom37_atom_exists"]
+            outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
+        else:
+            # Save embeddings for use during the next recycling iteration
+            # [*, N, C_m]
+            m_1_prev = m[..., 0, :, :]
+
+            # [*, N, N, C_z]
+            z_prev = z
+
+            # [*, N, 3]
+            x_prev = outputs["final_atom_positions"]
+
+        if no_iter != 3:
+            return None, m_1_prev, z_prev, x_prev
+        else:
+            return outputs, m_1_prev, z_prev, x_prev
 
     def _disable_activation_checkpointing(self):
         self.template_embedder.template_pair_stack.blocks_per_ckpt = None
@@ -482,6 +505,7 @@ class AlphaFold(nn.Module):
         """
         # Initialize recycling embeddings
         m_1_prev, z_prev, x_prev = None, None, None
+        prevs = [m_1_prev, z_prev, x_prev]
 
         # Disable activation checkpointing for the first few recycling iters
         is_grad_enabled = torch.is_grad_enabled()
@@ -506,11 +530,14 @@ class AlphaFold(nn.Module):
                 # Run the next iteration of the model
                 outputs, m_1_prev, z_prev, x_prev = self.iteration(
                     feats,
-                    m_1_prev,
-                    z_prev,
-                    x_prev,
-                    _recycle=(num_iters > 1)
+                    prevs,
+                    _recycle=(num_iters > 1),
+                    no_iter=cycle_no
                 )
+                if cycle_no != 3:
+                    prevs = [m_1_prev, z_prev, x_prev]
+                    del m_1_prev, z_prev, x_prev
 
         # Run auxiliary heads
         outputs.update(self.aux_heads(outputs))
```
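The recycling loop now passes its carry-over state as a single `prevs` list and only builds the full output dictionary (structure-module results, final atom positions, affine tensor) on the last cycle, identified by the hard-coded `no_iter == 3`; earlier cycles return `None` plus the three recycling embeddings. Below is a minimal sketch of that control flow; `fake_iteration`, `fake_forward`, and `NUM_RECYCLES` are hypothetical stand-ins for the real `iteration`/`forward`, and the "trunk" is a placeholder arithmetic step rather than the actual model.

```python
NUM_RECYCLES = 4  # cycles 0..3; the commit hard-codes the last cycle as no_iter == 3


def fake_iteration(feats, prevs, no_iter=0):
    # Unpack exactly like the commit: reversed pops restore (m_1_prev, z_prev, x_prev).
    m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
    m = feats + (m_1_prev or 0.0)  # placeholder for the Evoformer trunk
    if no_iter == NUM_RECYCLES - 1:
        # Full outputs are only materialized on the final recycling iteration.
        return {"final": m}, None, None, None
    # Earlier cycles hand back only what the next cycle needs.
    return None, m, 0.0, 0.0


def fake_forward(feats):
    prevs = [None, None, None]
    outputs = None
    for cycle_no in range(NUM_RECYCLES):
        outputs, m_1_prev, z_prev, x_prev = fake_iteration(feats, prevs, no_iter=cycle_no)
        if cycle_no != NUM_RECYCLES - 1:
            # Re-pack the carry-over state and drop the loose references,
            # mirroring the `prevs = [...]` / `del ...` pattern in forward().
            prevs = [m_1_prev, z_prev, x_prev]
            del m_1_prev, z_prev, x_prev
    return outputs


print(fake_forward(1.0))  # {'final': 4.0} after four recycling passes
```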
fastfold/model/nn/structure_module.py

```diff
@@ -389,6 +389,8 @@ class InvariantPointAttention(nn.Module):
         )
 
         a *= math.sqrt(1.0 / (3 * self.c_hidden))
         a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))
+        torch.cuda.empty_cache()
+
         if self.is_multimer:
             # [*, N_res, N_res, H, P_q, 3]
@@ -396,11 +398,27 @@ class InvariantPointAttention(nn.Module):
             # [*, N_res, N_res, H, P_q]
             pt_att = sum([c**2 for c in pt_att])
         else:
-            # [*, N_res, N_res, H, P_q, 3]
-            pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
-            pt_att = pt_att**2
-
-            # [*, N_res, N_res, H, P_q]
-            pt_att = sum(torch.unbind(pt_att, dim=-1))
+            # # [*, N_res, N_res, H, P_q, 3]
+            # pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
+            # pt_att = pt_att**2
+            # # [*, N_res, N_res, H, P_q]
+            # pt_att = sum(torch.unbind(pt_att, dim=-1))
+            _chunks = 10
+            _ks = k_pts.unsqueeze(-5)
+            _qlist = torch.chunk(q_pts.unsqueeze(-4), chunks=_chunks, dim=0)
+            pt_att = None
+            for _i in range(_chunks):
+                _pt = _qlist[_i] - _ks
+                _pt = _pt**2
+                _pt = sum(torch.unbind(_pt, dim=-1))
+                if _i == 0:
+                    pt_att = _pt
+                else:
+                    pt_att = torch.cat([pt_att, _pt], dim=0)
+            torch.cuda.empty_cache()
 
         head_weights = self.softplus(self.head_weights).view(
             *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
@@ -430,6 +448,8 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * C_hidden]
         o = flatten_final_dims(o, 2)
+        torch.cuda.empty_cache()
+
         # As DeepMind explains, this manual matmul ensures that the operation
         # happens in float32.
@@ -447,14 +467,34 @@ class InvariantPointAttention(nn.Module):
             # [*, N_res, H * P_v]
             o_pt_norm = o_pt.norm(self.eps)
         else:
-            # [*, H, 3, N_res, P_v]
-            o_pt = torch.sum(
-                (a[..., None, :, :, None]
-                 * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
-                ),
-                dim=-2,
-            )
+            # chunk permuted_pts
+            permuted_pts = permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+            size0 = permuted_pts.size()[0]
+            a_size0 = a.size()[0]
+            if size0 == 1 or size0 != a_size0:
+                # # [*, H, 3, N_res, P_v]
+                o_pt = torch.sum(
+                    (a[..., None, :, :, None]
+                     * permuted_pts
+                    ),
+                    dim=-2,
+                )
+            else:
+                a_lists = torch.chunk(a[..., None, :, :, None], size0, dim=0)
+                permuted_pts_lists = torch.chunk(permuted_pts, size0, dim=0)
+                _c = None
+                for i in range(size0):
+                    _d = a_lists[i] * permuted_pts_lists[i]
+                    _d = torch.sum(_d[..., None, :, :], dim=-2)
+                    if i == 0:
+                        _c = _d
+                    else:
+                        _c = torch.cat([_c, _d], dim=0)
+                o_pt = torch.sum(_c, dim=-2)
+            del permuted_pts
 
         # [*, N_res, H, P_v, 3]
         o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
@@ -471,8 +511,11 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H, C_z]
         o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
+        torch.cuda.empty_cache()
+
         # [*, N_res, H * C_z]
         o_pair = flatten_final_dims(o_pair, 2)
+        del a
 
         # [*, N_res, C_s]
         if self.is_multimer:
@@ -485,6 +528,7 @@ class InvariantPointAttention(nn.Module):
                 (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
             ).to(dtype=z.dtype)
         )
+        torch.cuda.empty_cache()
 
         return s
@@ -720,10 +764,16 @@ class StructureModule(nn.Module):
         # [*, N, N, C_z]
         z = self.layer_norm_z(z)
+        # inplace z
+        # z[0] = z[0].contiguous()
+        # torch.cuda.emtpy_cache()
+        # z[0] = self.layer_norm_z(z[0])
 
         # [*, N, C_s]
         s_initial = s
         s = self.linear_in(s)
+        torch.cuda.empty_cache()
 
         # [*, N]
         rigids = Rigid.identity(
@@ -737,9 +787,12 @@ class StructureModule(nn.Module):
         for i in range(self.no_blocks):
             # [*, N, C_s]
             s = s + self.ipa(s, z, rigids, mask)
+            del z
             s = self.ipa_dropout(s)
+            torch.cuda.empty_cache()
             s = self.layer_norm_ipa(s)
             s = self.transition(s)
+            torch.cuda.empty_cache()
 
             # [*, N]
             rigids = rigids.compose_q_update_vec(self.bb_update(s))
```
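In `InvariantPointAttention`, the pairwise point-attention term previously materialized the full `q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)` tensor of shape [*, N_res, N_res, H, P_q, 3] in one step; the commit computes it slice by slice along dim 0 and concatenates the partial sums, so only one slice of that tensor is resident at a time. The sketch below reproduces the same trick on toy shapes and checks it against the dense computation; unlike the commit, it iterates whatever `torch.chunk` returns rather than a fixed `range(10)`, and the shapes and chunk count are illustrative only.

```python
import torch


def pt_att_full(q_pts, k_pts):
    # Dense version: builds the whole pairwise difference tensor at once.
    d = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
    return sum(torch.unbind(d ** 2, dim=-1))


def pt_att_chunked(q_pts, k_pts, chunks=10):
    # Chunked version: process query-point slices along dim 0 and concatenate,
    # keeping only a slice-sized intermediate alive at any moment.
    ks = k_pts.unsqueeze(-5)
    parts = []
    for q in torch.chunk(q_pts.unsqueeze(-4), chunks=chunks, dim=0):
        d = (q - ks) ** 2
        parts.append(sum(torch.unbind(d, dim=-1)))
    return torch.cat(parts, dim=0)


# Toy tensors shaped like [N_res, H, P_q, 3]; real IPA inputs carry extra batch dims.
q = torch.randn(20, 8, 4, 3)
k = torch.randn(20, 8, 4, 3)
assert torch.allclose(pt_att_full(q, k), pt_att_chunked(q, k))
print(pt_att_chunked(q, k).shape)  # torch.Size([20, 20, 8, 4])
```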
inference.py

```diff
@@ -434,21 +434,21 @@ def inference_monomer_model(args):
     with open(unrelaxed_output_path, 'w') as f:
         f.write(protein.to_pdb(unrelaxed_protein))
 
-    amber_relaxer = relax.AmberRelaxation(
-        use_gpu=True,
-        **config.relax,
-    )
-
-    # Relax the prediction.
-    t = time.perf_counter()
-    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-    print(f"Relaxation time: {time.perf_counter() - t}")
-
-    # Save the relaxed PDB.
-    relaxed_output_path = os.path.join(args.output_dir,
-                                       f'{tag}_{args.model_name}_relaxed.pdb')
-    with open(relaxed_output_path, 'w') as f:
-        f.write(relaxed_pdb_str)
+    # amber_relaxer = relax.AmberRelaxation(
+    #     use_gpu=True,
+    #     **config.relax,
+    # )
+
+    # # Relax the prediction.
+    # t = time.perf_counter()
+    # relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+    # print(f"Relaxation time: {time.perf_counter() - t}")
+
+    # # Save the relaxed PDB.
+    # relaxed_output_path = os.path.join(args.output_dir,
+    #                                    f'{tag}_{args.model_name}_relaxed.pdb')
+    # with open(relaxed_output_path, 'w') as f:
+    #     f.write(relaxed_pdb_str)
 
 
 if __name__ == "__main__":
```
inference.sh

```diff
@@ -14,4 +14,6 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files \
     --jackhmmer_binary_path `which jackhmmer` \
     --hhblits_binary_path `which hhblits` \
     --hhsearch_binary_path `which hhsearch` \
-    --kalign_binary_path `which kalign`
+    --kalign_binary_path `which kalign` \
+    --chunk_size 4 \
+    --inplace
```